Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-08 22:28:01 -05:00

Commit: 32b1a7ab1d (parent: 184f40439e)
Committed by: Agnès Leroy
Subject: refactor(gpu): unchecked_match_value_or to backend
@@ -75,3 +75,55 @@ template <typename Torus> struct int_extend_radix_with_sign_msb_buffer {
    cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
  }
};

template <typename Torus> struct int_cast_to_unsigned_buffer {
  int_radix_params params;
  bool allocate_gpu_memory;

  bool requires_full_propagate;
  bool requires_sign_extension;

  int_fullprop_buffer<Torus> *prop_buffer;
  int_extend_radix_with_sign_msb_buffer<Torus> *extend_buffer;

  int_cast_to_unsigned_buffer(CudaStreams streams, int_radix_params params,
                              uint32_t num_input_blocks,
                              uint32_t target_num_blocks, bool input_is_signed,
                              bool requires_full_propagate,
                              bool allocate_gpu_memory,
                              uint64_t &size_tracker) {
    this->params = params;
    this->allocate_gpu_memory = allocate_gpu_memory;
    this->requires_full_propagate = requires_full_propagate;

    this->prop_buffer = nullptr;
    this->extend_buffer = nullptr;

    if (requires_full_propagate) {
      this->prop_buffer = new int_fullprop_buffer<Torus>(
          streams, params, allocate_gpu_memory, size_tracker);
    }

    this->requires_sign_extension =
        (target_num_blocks > num_input_blocks) && input_is_signed;

    if (this->requires_sign_extension) {
      uint32_t num_blocks_to_add = target_num_blocks - num_input_blocks;
      this->extend_buffer = new int_extend_radix_with_sign_msb_buffer<Torus>(
          streams, params, num_input_blocks, num_blocks_to_add,
          allocate_gpu_memory, size_tracker);
    }
  }

  void release(CudaStreams streams) {
    if (this->prop_buffer) {
      this->prop_buffer->release(streams);
      delete this->prop_buffer;
    }
    if (this->extend_buffer) {
      this->extend_buffer->release(streams);
      delete this->extend_buffer;
    }
    cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
  }
};
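The new buffer follows the same allocate-in-constructor / explicit-release pattern as the other int_*_buffer types in this header. A minimal lifecycle sketch with hypothetical block counts (streams and params are assumed to be set up as for the other buffers; this snippet is illustrative, not part of the commit):

// Sketch only: casting a 4-block signed radix integer up to 8 unsigned
// blocks needs sign extension, so the constructor allocates extend_buffer;
// size_tracker accumulates the number of bytes reserved on the GPU.
uint64_t size_tracker = 0;
auto *mem = new int_cast_to_unsigned_buffer<uint64_t>(
    streams, params, /*num_input_blocks=*/4, /*target_num_blocks=*/8,
    /*input_is_signed=*/true, /*requires_full_propagate=*/false,
    /*allocate_gpu_memory=*/true, size_tracker);
// ... host_cast_to_unsigned(...) consumes the buffer here ...
mem->release(streams);
delete mem;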
@@ -569,6 +569,10 @@ void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
                              CudaRadixCiphertextFFI const *input,
                              CudaStreamsFFI streams);

void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
                              CudaRadixCiphertextFFI const *input,
                              CudaStreamsFFI streams);

uint64_t scratch_cuda_apply_noise_squashing(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t lwe_dimension,
    uint32_t glwe_dimension, uint32_t polynomial_size,
@@ -850,6 +854,46 @@ void cuda_unchecked_match_value_64(

void cleanup_cuda_unchecked_match_value_64(CudaStreamsFFI streams,
                                           int8_t **mem_ptr_void);

uint64_t scratch_cuda_cast_to_unsigned_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
    bool requires_full_propagate, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type);

void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
                              CudaRadixCiphertextFFI *output,
                              CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
                              uint32_t target_num_blocks, bool input_is_signed,
                              void *const *bsks, void *const *ksks);

void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
                                      int8_t **mem_ptr_void);

uint64_t scratch_cuda_unchecked_match_value_or_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_matches, uint32_t num_input_blocks,
    uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
    uint32_t max_output_is_zero, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type);

void cuda_unchecked_match_value_or_64(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
    CudaRadixCiphertextFFI const *lwe_array_in_ct,
    const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
    const uint64_t *h_or_value, int8_t *mem, void *const *bsks,
    void *const *ksks);

void cleanup_cuda_unchecked_match_value_or_64(CudaStreamsFFI streams,
                                              int8_t **mem_ptr_void);
} // extern C

#endif // CUDA_INTEGER_H
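Each new operation is exposed through the usual scratch / compute / cleanup triple. A minimal sketch of the calling sequence for the cast entry points declared above (the wrapper function is hypothetical; every argument is assumed to be obtained the same way as for the existing *_64 entry points):

// Sketch only: drives scratch_cuda_cast_to_unsigned_64 end to end.
uint64_t cast_to_unsigned_example(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *output,
    CudaRadixCiphertextFFI *input, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_input_blocks, uint32_t target_num_blocks,
    bool input_is_signed, bool requires_full_propagate,
    uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
    PBS_MS_REDUCTION_T noise_reduction_type, void *const *bsks,
    void *const *ksks) {
  int8_t *mem_ptr = nullptr;
  // Passing allocate_gpu_memory = false instead would produce a dry-run
  // size estimate without touching device memory.
  uint64_t size = scratch_cuda_cast_to_unsigned_64(
      streams, &mem_ptr, glwe_dimension, polynomial_size, big_lwe_dimension,
      small_lwe_dimension, ks_level, ks_base_log, pbs_level, pbs_base_log,
      grouping_factor, num_input_blocks, target_num_blocks, input_is_signed,
      requires_full_propagate, message_modulus, carry_modulus, pbs_type,
      /*allocate_gpu_memory=*/true, noise_reduction_type);
  cuda_cast_to_unsigned_64(streams, output, input, mem_ptr, target_num_blocks,
                           input_is_signed, bsks, ksks);
  cleanup_cuda_cast_to_unsigned_64(streams, &mem_ptr);
  return size;
}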
@@ -1,4 +1,5 @@
#pragma once
#include "cast.h"
#include "integer/comparison.h"
#include "integer/radix_ciphertext.cuh"
#include "integer_utilities.h"
@@ -593,3 +594,98 @@ template <typename Torus> struct int_unchecked_match_buffer {
    delete this->packed_selectors_ct;
  }
};

template <typename Torus> struct int_unchecked_match_value_or_buffer {
  int_radix_params params;
  bool allocate_gpu_memory;

  uint32_t num_matches;
  uint32_t num_input_blocks;
  uint32_t num_match_packed_blocks;
  uint32_t num_final_blocks;
  bool max_output_is_zero;

  int_unchecked_match_buffer<Torus> *match_buffer;
  int_cmux_buffer<Torus> *cmux_buffer;

  CudaRadixCiphertextFFI *tmp_match_result;
  CudaRadixCiphertextFFI *tmp_match_bool;
  CudaRadixCiphertextFFI *tmp_or_value;

  Torus *d_or_value;

  int_unchecked_match_value_or_buffer(
      CudaStreams streams, int_radix_params params, uint32_t num_matches,
      uint32_t num_input_blocks, uint32_t num_match_packed_blocks,
      uint32_t num_final_blocks, bool max_output_is_zero,
      bool allocate_gpu_memory, uint64_t &size_tracker) {
    this->params = params;
    this->allocate_gpu_memory = allocate_gpu_memory;
    this->num_matches = num_matches;
    this->num_input_blocks = num_input_blocks;
    this->num_match_packed_blocks = num_match_packed_blocks;
    this->num_final_blocks = num_final_blocks;
    this->max_output_is_zero = max_output_is_zero;

    this->match_buffer = new int_unchecked_match_buffer<Torus>(
        streams, params, num_matches, num_input_blocks,
        num_match_packed_blocks, max_output_is_zero, allocate_gpu_memory,
        size_tracker);

    this->cmux_buffer = new int_cmux_buffer<Torus>(
        streams, [](Torus x) -> Torus { return x == 1; }, params,
        num_final_blocks, allocate_gpu_memory, size_tracker);

    this->tmp_match_result = new CudaRadixCiphertextFFI;
    this->tmp_match_bool = new CudaRadixCiphertextFFI;
    this->tmp_or_value = new CudaRadixCiphertextFFI;

    this->d_or_value = (Torus *)cuda_malloc_with_size_tracking_async(
        num_final_blocks * sizeof(Torus), streams.stream(0),
        streams.gpu_index(0), size_tracker, allocate_gpu_memory);

    if (!max_output_is_zero) {
      create_zero_radix_ciphertext_async<Torus>(
          streams.stream(0), streams.gpu_index(0), this->tmp_match_result,
          num_final_blocks, params.big_lwe_dimension, size_tracker,
          allocate_gpu_memory);
    }

    create_zero_radix_ciphertext_async<Torus>(
        streams.stream(0), streams.gpu_index(0), this->tmp_match_bool, 1,
        params.big_lwe_dimension, size_tracker, allocate_gpu_memory);

    create_zero_radix_ciphertext_async<Torus>(
        streams.stream(0), streams.gpu_index(0), this->tmp_or_value,
        num_final_blocks, params.big_lwe_dimension, size_tracker,
        allocate_gpu_memory);
  }

  void release(CudaStreams streams) {
    this->match_buffer->release(streams);
    delete this->match_buffer;

    this->cmux_buffer->release(streams);
    delete this->cmux_buffer;

    if (!max_output_is_zero) {
      release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
                                     this->tmp_match_result,
                                     this->allocate_gpu_memory);
    }
    delete this->tmp_match_result;

    release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
                                   this->tmp_match_bool,
                                   this->allocate_gpu_memory);
    delete this->tmp_match_bool;

    release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
                                   this->tmp_or_value,
                                   this->allocate_gpu_memory);
    delete this->tmp_or_value;

    cuda_drop_async(this->d_or_value, streams.stream(0), streams.gpu_index(0));

    cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
  }
};
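Functionally, the buffers above stage the encrypted equivalent of a table lookup with a default value: the output is the matched value when the input appears in the match table, and or_value otherwise. A cleartext model of that semantics (illustration only, not part of the commit):

#include <cstdint>
#include <map>

// Cleartext semantics of unchecked_match_value_or (illustration only).
uint64_t match_value_or_clear(uint64_t input,
                              const std::map<uint64_t, uint64_t> &matches,
                              uint64_t or_value) {
  auto it = matches.find(input);
  return it != matches.end() ? it->second : or_value;
}

The max_output_is_zero flag mirrors the optimization visible in the constructor: when every match output is zero, the match result is trivially zero, so tmp_match_result never needs a real device allocation.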
@@ -18,6 +18,15 @@ void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
  cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
}

void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
                              CudaRadixCiphertextFFI const *input,
                              CudaStreamsFFI streams) {

  auto cuda_streams = CudaStreams(streams);
  host_trim_radix_blocks_msb<uint64_t>(output, input, cuda_streams);
  cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
}

uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
@@ -64,3 +73,46 @@ void cleanup_cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
  delete mem_ptr;
  *mem_ptr_void = nullptr;
}

uint64_t scratch_cuda_cast_to_unsigned_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
    bool requires_full_propagate, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type) {

  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
                          big_lwe_dimension, small_lwe_dimension, ks_level,
                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
                          message_modulus, carry_modulus, noise_reduction_type);

  return scratch_cuda_cast_to_unsigned<uint64_t>(
      CudaStreams(streams), (int_cast_to_unsigned_buffer<uint64_t> **)mem_ptr,
      params, num_input_blocks, target_num_blocks, input_is_signed,
      requires_full_propagate, allocate_gpu_memory);
}

void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
                              CudaRadixCiphertextFFI *output,
                              CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
                              uint32_t target_num_blocks, bool input_is_signed,
                              void *const *bsks, void *const *ksks) {

  host_cast_to_unsigned<uint64_t>(
      CudaStreams(streams), output, input,
      (int_cast_to_unsigned_buffer<uint64_t> *)mem_ptr, target_num_blocks,
      input_is_signed, bsks, (uint64_t **)ksks);
}

void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
                                      int8_t **mem_ptr_void) {
  int_cast_to_unsigned_buffer<uint64_t> *mem_ptr =
      (int_cast_to_unsigned_buffer<uint64_t> *)(*mem_ptr_void);

  mem_ptr->release(CudaStreams(streams));
  delete mem_ptr;
  *mem_ptr_void = nullptr;
}
@@ -36,6 +36,23 @@ __host__ void host_trim_radix_blocks_lsb(CudaRadixCiphertextFFI *output,
                                 input->num_radix_blocks);
}

template <typename Torus>
__host__ void
host_trim_radix_blocks_msb(CudaRadixCiphertextFFI *output_radix,
                           const CudaRadixCiphertextFFI *input_radix,
                           CudaStreams streams) {

  PANIC_IF_FALSE(input_radix->num_radix_blocks >=
                     output_radix->num_radix_blocks,
                 "Cuda error: input radix ciphertext has fewer blocks than "
                 "required to keep");

  copy_radix_ciphertext_slice_async<Torus>(
      streams.stream(0), streams.gpu_index(0), output_radix, 0,
      output_radix->num_radix_blocks, input_radix, 0,
      output_radix->num_radix_blocks);
}
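// Note (not from the commit): radix blocks are stored least-significant
// first, so trimming MSB blocks is a single contiguous slice copy. With
// hypothetical counts, trimming an 8-block input into a 4-block output
// copies input blocks [0, 4), the least-significant half, into output
// blocks [0, 4) asynchronously on stream 0; no keyswitch or bootstrap runs.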
template <typename Torus>
__host__ uint64_t scratch_extend_radix_with_sign_msb(
    CudaStreams streams, int_extend_radix_with_sign_msb_buffer<Torus> **mem_ptr,
@@ -91,4 +108,56 @@ __host__ void host_extend_radix_with_sign_msb(
  POP_RANGE()
}

template <typename Torus>
uint64_t scratch_cuda_cast_to_unsigned(
    CudaStreams streams, int_cast_to_unsigned_buffer<Torus> **mem_ptr,
    int_radix_params params, uint32_t num_input_blocks,
    uint32_t target_num_blocks, bool input_is_signed,
    bool requires_full_propagate, bool allocate_gpu_memory) {

  uint64_t size_tracker = 0;
  *mem_ptr = new int_cast_to_unsigned_buffer<Torus>(
      streams, params, num_input_blocks, target_num_blocks, input_is_signed,
      requires_full_propagate, allocate_gpu_memory, size_tracker);

  return size_tracker;
}

template <typename Torus>
__host__ void
host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
                      CudaRadixCiphertextFFI *input,
                      int_cast_to_unsigned_buffer<Torus> *mem_ptr,
                      uint32_t target_num_blocks, bool input_is_signed,
                      void *const *bsks, Torus *const *ksks) {

  uint32_t current_num_blocks = input->num_radix_blocks;

  if (mem_ptr->requires_full_propagate) {
    host_full_propagate_inplace<Torus>(streams, input, mem_ptr->prop_buffer,
                                       ksks, bsks, current_num_blocks);
  }

  if (target_num_blocks > current_num_blocks) {
    uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;

    if (input_is_signed) {
      host_extend_radix_with_sign_msb<Torus>(
          streams, output, input, mem_ptr->extend_buffer, num_blocks_to_add,
          bsks, (Torus **)ksks);
    } else {
      host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
                                                            streams);
    }

  } else if (target_num_blocks < current_num_blocks) {
    host_trim_radix_blocks_msb<Torus>(output, input, streams);

  } else {
    copy_radix_ciphertext_slice_async<Torus>(
        streams.stream(0), streams.gpu_index(0), output, 0, current_num_blocks,
        input, 0, current_num_blocks);
  }
}

#endif
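After the optional carry propagation, host_cast_to_unsigned reduces to four cases, and only the signed extension path needs programmable bootstrapping:

// Summary of the dispatch above (illustration only):
//   target > current, signed input   -> host_extend_radix_with_sign_msb (PBS)
//   target > current, unsigned input -> append trivial zero blocks (no PBS)
//   target < current                 -> host_trim_radix_blocks_msb (slice copy)
//   target == current                -> straight slice copy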
@@ -173,3 +173,51 @@ void cleanup_cuda_unchecked_match_value_64(CudaStreamsFFI streams,
  delete mem_ptr;
  *mem_ptr_void = nullptr;
}

uint64_t scratch_cuda_unchecked_match_value_or_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_matches, uint32_t num_input_blocks,
    uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
    uint32_t max_output_is_zero, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type) {

  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
                          big_lwe_dimension, small_lwe_dimension, ks_level,
                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
                          message_modulus, carry_modulus, noise_reduction_type);

  return scratch_cuda_unchecked_match_value_or<uint64_t>(
      CudaStreams(streams),
      (int_unchecked_match_value_or_buffer<uint64_t> **)mem_ptr, params,
      num_matches, num_input_blocks, num_match_packed_blocks, num_final_blocks,
      max_output_is_zero, allocate_gpu_memory);
}

void cuda_unchecked_match_value_or_64(
    CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
    CudaRadixCiphertextFFI const *lwe_array_in_ct,
    const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
    const uint64_t *h_or_value, int8_t *mem, void *const *bsks,
    void *const *ksks) {

  host_unchecked_match_value_or<uint64_t>(
      CudaStreams(streams), lwe_array_out, lwe_array_in_ct, h_match_inputs,
      h_match_outputs, h_or_value,
      (int_unchecked_match_value_or_buffer<uint64_t> *)mem, bsks,
      (uint64_t *const *)ksks);
}

void cleanup_cuda_unchecked_match_value_or_64(CudaStreamsFFI streams,
                                              int8_t **mem_ptr_void) {
  int_unchecked_match_value_or_buffer<uint64_t> *mem_ptr =
      (int_unchecked_match_value_or_buffer<uint64_t> *)(*mem_ptr_void);

  mem_ptr->release(CudaStreams(streams));

  delete mem_ptr;
  *mem_ptr_void = nullptr;
}
@@ -1,5 +1,7 @@
#pragma once

#include "integer/cast.cuh"
#include "integer/cmux.cuh"
#include "integer/comparison.cuh"
#include "integer/integer.cuh"
#include "integer/radix_ciphertext.cuh"
@@ -453,3 +455,46 @@ uint64_t scratch_cuda_unchecked_match_value(

  return size_tracker;
}

template <typename Torus>
uint64_t scratch_cuda_unchecked_match_value_or(
    CudaStreams streams, int_unchecked_match_value_or_buffer<Torus> **mem_ptr,
    int_radix_params params, uint32_t num_matches, uint32_t num_input_blocks,
    uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
    bool max_output_is_zero, bool allocate_gpu_memory) {

  uint64_t size_tracker = 0;
  *mem_ptr = new int_unchecked_match_value_or_buffer<Torus>(
      streams, params, num_matches, num_input_blocks, num_match_packed_blocks,
      num_final_blocks, max_output_is_zero, allocate_gpu_memory, size_tracker);

  return size_tracker;
}

template <typename Torus>
__host__ void host_unchecked_match_value_or(
    CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
    CudaRadixCiphertextFFI const *lwe_array_in_ct,
    const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
    const uint64_t *h_or_value,
    int_unchecked_match_value_or_buffer<Torus> *mem_ptr, void *const *bsks,
    Torus *const *ksks) {

  host_unchecked_match_value<Torus>(streams, mem_ptr->tmp_match_result,
                                    mem_ptr->tmp_match_bool, lwe_array_in_ct,
                                    h_match_inputs, h_match_outputs,
                                    mem_ptr->match_buffer, bsks, ksks);

  cuda_memcpy_async_to_gpu(mem_ptr->d_or_value, h_or_value,
                           mem_ptr->num_final_blocks * sizeof(Torus),
                           streams.stream(0), streams.gpu_index(0));

  set_trivial_radix_ciphertext_async<Torus>(
      streams.stream(0), streams.gpu_index(0), mem_ptr->tmp_or_value,
      mem_ptr->d_or_value, (Torus *)h_or_value, mem_ptr->num_final_blocks,
      mem_ptr->params.message_modulus, mem_ptr->params.carry_modulus);

  host_cmux<Torus>(streams, lwe_array_out, mem_ptr->tmp_match_bool,
                   mem_ptr->tmp_match_result, mem_ptr->tmp_or_value,
                   mem_ptr->cmux_buffer, bsks, (Torus **)ksks);
}
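host_unchecked_match_value_or therefore runs as a three-stage pipeline on stream 0:

// 1. host_unchecked_match_value produces the encrypted match result in
//    tmp_match_result and an encrypted "something matched" boolean in
//    tmp_match_bool.
// 2. The cleartext or_value blocks are copied to d_or_value and wrapped
//    into the trivial radix ciphertext tmp_or_value.
// 3. host_cmux selects tmp_match_result when tmp_match_bool is 1 and
//    tmp_or_value otherwise (the `x == 1` lambda configured in the buffer).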
@@ -1270,6 +1270,13 @@ unsafe extern "C" {
        streams: CudaStreamsFFI,
    );
}
unsafe extern "C" {
    pub fn trim_radix_blocks_msb_64(
        output: *mut CudaRadixCiphertextFFI,
        input: *const CudaRadixCiphertextFFI,
        streams: CudaStreamsFFI,
    );
}
unsafe extern "C" {
    pub fn scratch_cuda_apply_noise_squashing(
        streams: CudaStreamsFFI,
@@ -1872,6 +1879,89 @@ unsafe extern "C" {
        mem_ptr_void: *mut *mut i8,
    );
}
unsafe extern "C" {
    pub fn scratch_cuda_cast_to_unsigned_64(
        streams: CudaStreamsFFI,
        mem_ptr: *mut *mut i8,
        glwe_dimension: u32,
        polynomial_size: u32,
        big_lwe_dimension: u32,
        small_lwe_dimension: u32,
        ks_level: u32,
        ks_base_log: u32,
        pbs_level: u32,
        pbs_base_log: u32,
        grouping_factor: u32,
        num_input_blocks: u32,
        target_num_blocks: u32,
        input_is_signed: bool,
        requires_full_propagate: bool,
        message_modulus: u32,
        carry_modulus: u32,
        pbs_type: PBS_TYPE,
        allocate_gpu_memory: bool,
        noise_reduction_type: PBS_MS_REDUCTION_T,
    ) -> u64;
}
unsafe extern "C" {
    pub fn cuda_cast_to_unsigned_64(
        streams: CudaStreamsFFI,
        output: *mut CudaRadixCiphertextFFI,
        input: *mut CudaRadixCiphertextFFI,
        mem_ptr: *mut i8,
        target_num_blocks: u32,
        input_is_signed: bool,
        bsks: *const *mut ffi::c_void,
        ksks: *const *mut ffi::c_void,
    );
}
unsafe extern "C" {
    pub fn cleanup_cuda_cast_to_unsigned_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
unsafe extern "C" {
    pub fn scratch_cuda_unchecked_match_value_or_64(
        streams: CudaStreamsFFI,
        mem_ptr: *mut *mut i8,
        glwe_dimension: u32,
        polynomial_size: u32,
        big_lwe_dimension: u32,
        small_lwe_dimension: u32,
        ks_level: u32,
        ks_base_log: u32,
        pbs_level: u32,
        pbs_base_log: u32,
        grouping_factor: u32,
        num_matches: u32,
        num_input_blocks: u32,
        num_match_packed_blocks: u32,
        num_final_blocks: u32,
        max_output_is_zero: u32,
        message_modulus: u32,
        carry_modulus: u32,
        pbs_type: PBS_TYPE,
        allocate_gpu_memory: bool,
        noise_reduction_type: PBS_MS_REDUCTION_T,
    ) -> u64;
}
unsafe extern "C" {
    pub fn cuda_unchecked_match_value_or_64(
        streams: CudaStreamsFFI,
        lwe_array_out: *mut CudaRadixCiphertextFFI,
        lwe_array_in_ct: *const CudaRadixCiphertextFFI,
        h_match_inputs: *const u64,
        h_match_outputs: *const u64,
        h_or_value: *const u64,
        mem: *mut i8,
        bsks: *const *mut ffi::c_void,
        ksks: *const *mut ffi::c_void,
    );
}
unsafe extern "C" {
    pub fn cleanup_cuda_unchecked_match_value_or_64(
        streams: CudaStreamsFFI,
        mem_ptr_void: *mut *mut i8,
    );
}
unsafe extern "C" {
    pub fn scratch_cuda_integer_compress_radix_ciphertext_64(
        streams: CudaStreamsFFI,
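// As with the other GPU integer operations, each new primitive is bound as a
// scratch / compute / cleanup triple:
//   scratch_cuda_*_64  -> allocates (or only sizes) the temporary buffers,
//                         returning the byte count as u64
//   cuda_*_64          -> runs the operation against the scratch buffer
//   cleanup_cuda_*_64  -> releases the scratch buffer and nulls the pointer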
@@ -1363,6 +1363,63 @@ where
        })
    }

    /// Returns the estimated memory usage (in bytes) required on the GPU to perform
    /// the `match_value_or` operation.
    ///
    /// This is useful to check if the operation fits in the GPU memory before attempting execution.
    ///
    /// # Example
    ///
    /// ```rust
    /// use tfhe::prelude::*;
    /// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16, MatchValues};
    ///
    /// let config = ConfigBuilder::default().build();
    /// let (client_key, server_key) = generate_keys(config);
    /// set_server_key(server_key);
    ///
    /// let a = FheUint16::encrypt(17u16, &client_key);
    ///
    /// let match_values = MatchValues::new(vec![(0u16, 3u16), (17u16, 25u16)]).unwrap();
    ///
    /// #[cfg(feature = "gpu")]
    /// {
    ///     let size_bytes = a
    ///         .get_match_value_or_size_on_gpu(&match_values, 55u16)
    ///         .unwrap();
    ///     println!("Memory required on GPU: {} bytes", size_bytes);
    ///     assert!(size_bytes > 0);
    /// }
    /// ```
    #[cfg(feature = "gpu")]
    pub fn get_match_value_or_size_on_gpu<Clear>(
        &self,
        matches: &MatchValues<Clear>,
        or_value: Clear,
    ) -> crate::Result<u64>
    where
        Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<usize>,
    {
        global_state::with_internal_keys(|key| match key {
            InternalServerKey::Cpu(_) => Err(crate::Error::new(
                "This function is only available when using the CUDA backend".to_string(),
            )),
            InternalServerKey::Cuda(cuda_key) => {
                let streams = &cuda_key.streams;
                let ct_on_gpu = self.ciphertext.on_gpu(streams);

                let size = cuda_key.key.key.get_unchecked_match_value_or_size_on_gpu(
                    &ct_on_gpu, matches, or_value, streams,
                );
                Ok(size)
            }
            #[cfg(feature = "hpu")]
            InternalServerKey::Hpu(_device) => {
                panic!("Hpu does not support this operation.")
            }
        })
    }

    /// Reverse the bit of the unsigned integer
    ///
    /// # Example
@@ -653,3 +653,9 @@ fn test_compressed_cpk_encrypt_cast_compute_hl() {
    let clear: u64 = mul.decrypt(&client_key);
    assert_eq!(clear, (input_msg * multiplier) % modulus);
}

#[test]
fn test_match_value_or() {
    let client_key = setup_default_cpu();
    super::test_case_match_value_or(&client_key);
}
@@ -960,3 +960,42 @@ fn test_gpu_get_match_value_size_on_gpu() {
        assert!(memory_size > 0);
    }
}

#[test]
fn test_match_value_or_gpu() {
    let client_key = setup_classical_gpu();
    super::test_case_match_value_or(&client_key);
}

#[test]
fn test_match_value_or_gpu_multibit() {
    let client_key = setup_multibit_gpu();
    super::test_case_match_value_or(&client_key);
}

#[test]
fn test_gpu_get_match_value_or_size_on_gpu() {
    for setup_fn in GPU_SETUP_FN {
        let cks = setup_fn();
        let mut rng = rand::thread_rng();
        let clear_a = rng.gen::<u32>();
        let or_value = rng.gen::<u32>();

        let mut a = FheUint32::try_encrypt(clear_a, &cks).unwrap();
        a.move_to_current_device();

        let match_values = MatchValues::new(vec![
            (0u32, 10u32),
            (1u32, 20u32),
            (clear_a, 30u32),
            (u32::MAX, 40u32),
        ])
        .unwrap();

        let memory_size = a
            .get_match_value_or_size_on_gpu(&match_values, or_value)
            .unwrap();
        check_valid_cuda_malloc_assert_oom(memory_size, GpuIndex::new(0));
        assert!(memory_size > 0);
    }
}
@@ -839,3 +839,52 @@ fn test_case_match_value(cks: &ClientKey) {
        }
    }
}

fn test_case_match_value_or(cks: &ClientKey) {
    let mut rng = thread_rng();

    for _ in 0..5 {
        let clear_in = rng.gen::<u8>();
        let ct = FheUint8::encrypt(clear_in, cks);
        let clear_or_value = rng.gen::<u8>();

        let should_match = rng.gen_bool(0.5);

        let mut map: HashMap<u8, u8> = HashMap::new();
        let mut pairs = Vec::new();

        let expected_value = if should_match {
            let val = rng.gen::<u8>();
            map.insert(clear_in, val);
            pairs.push((clear_in, val));
            val
        } else {
            clear_or_value
        };

        let num_entries = rng.gen_range(1..10);
        for _ in 0..num_entries {
            let mut k = rng.gen::<u8>();
            while !should_match && k == clear_in {
                k = rng.gen::<u8>();
            }

            if let std::collections::hash_map::Entry::Vacant(e) = map.entry(k) {
                let v = rng.gen::<u8>();
                e.insert(v);
                pairs.push((k, v));
            }
        }

        let matches = MatchValues::new(pairs).unwrap();

        let result: FheUint8 = ct.match_value_or(&matches, clear_or_value).unwrap();

        let dec_result: u8 = result.decrypt(cks);

        assert_eq!(
            dec_result, expected_value,
            "Mismatch on result value for input {clear_in}. Should match: {should_match}"
        );
    }
}
@@ -71,18 +71,6 @@ impl CudaRadixCiphertextInfo {
         new_block_info
     }

-    pub(crate) fn after_trim_radix_blocks_msb(&self, num_blocks: usize) -> Self {
-        assert!(num_blocks > 0);
-
-        let mut new_block_info = Self {
-            blocks: Vec::with_capacity(self.blocks.len().saturating_sub(num_blocks)),
-        };
-        new_block_info
-            .blocks
-            .extend(self.blocks[..num_blocks].iter().copied());
-        new_block_info
-    }
-
     pub fn duplicate(&self) -> Self {
         Self {
             blocks: self
@@ -7504,6 +7504,34 @@ pub(crate) unsafe fn cuda_backend_trim_radix_blocks_lsb(
    update_noise_degree(output, &cuda_ffi_output);
}

/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
///   undefined behavior.
pub(crate) unsafe fn cuda_backend_trim_radix_blocks_msb(
    output: &mut CudaRadixCiphertext,
    input: &CudaRadixCiphertext,
    streams: &CudaStreams,
) {
    let mut input_degrees = input.info.blocks.iter().map(|b| b.degree.0).collect();
    let mut input_noise_levels = input.info.blocks.iter().map(|b| b.noise_level.0).collect();
    let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect();
    let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect();

    let mut cuda_ffi_output =
        prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);

    let cuda_ffi_input = prepare_cuda_radix_ffi(input, &mut input_degrees, &mut input_noise_levels);

    trim_radix_blocks_msb_64(
        &raw mut cuda_ffi_output,
        &raw const cuda_ffi_input,
        streams.ffi(),
    );
    update_noise_degree(output, &cuda_ffi_output);
}

/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
@@ -8518,6 +8546,11 @@ pub(crate) unsafe fn cuda_backend_compute_equality_selectors<T: UnsignedInteger,
}

#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
///   undefined behavior.
pub(crate) unsafe fn cuda_backend_create_possible_results<
    T: UnsignedInteger,
    B: Numeric,
@@ -8654,6 +8687,11 @@ pub(crate) unsafe fn cuda_backend_create_possible_results<
}

#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
///   undefined behavior.
pub(crate) unsafe fn cuda_backend_aggregate_one_hot_vector<
    T: UnsignedInteger,
    B: Numeric,
@@ -8790,6 +8828,11 @@ pub(crate) unsafe fn cuda_backend_aggregate_one_hot_vector<
}

#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
///   undefined behavior.
pub(crate) unsafe fn cuda_backend_unchecked_match_value<
    T: UnsignedInteger,
    B: Numeric,
@@ -8945,6 +8988,11 @@ pub(crate) unsafe fn cuda_backend_unchecked_match_value<
}

#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
///   undefined behavior.
pub(crate) fn cuda_backend_get_unchecked_match_value_size_on_gpu(
    streams: &CudaStreams,
    glwe_dimension: GlweDimension,
@@ -8999,3 +9047,276 @@ pub(crate) fn cuda_backend_get_unchecked_match_value_size_on_gpu(

    size_tracker
}

#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
///   undefined behavior.
pub(crate) fn cuda_backend_get_unchecked_match_value_or_size_on_gpu(
    streams: &CudaStreams,
    glwe_dimension: GlweDimension,
    polynomial_size: PolynomialSize,
    big_lwe_dimension: LweDimension,
    small_lwe_dimension: LweDimension,
    ks_level: DecompositionLevelCount,
    ks_base_log: DecompositionBaseLog,
    pbs_level: DecompositionLevelCount,
    pbs_base_log: DecompositionBaseLog,
    grouping_factor: LweBskGroupingFactor,
    message_modulus: MessageModulus,
    carry_modulus: CarryModulus,
    pbs_type: PBSType,
    num_matches: u32,
    num_input_blocks: u32,
    num_match_packed_blocks: u32,
    num_output_blocks: u32,
    max_output_is_zero: bool,
    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) -> u64 {
    let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
    let mut mem_ptr: *mut i8 = std::ptr::null_mut();

    let size_tracker = unsafe {
        scratch_cuda_unchecked_match_value_or_64(
            streams.ffi(),
            std::ptr::addr_of_mut!(mem_ptr),
            glwe_dimension.0 as u32,
            polynomial_size.0 as u32,
            big_lwe_dimension.0 as u32,
            small_lwe_dimension.0 as u32,
            ks_level.0 as u32,
            ks_base_log.0 as u32,
            pbs_level.0 as u32,
            pbs_base_log.0 as u32,
            grouping_factor.0 as u32,
            num_matches,
            num_input_blocks,
            num_match_packed_blocks,
            num_output_blocks,
            max_output_is_zero as u32,
            message_modulus.0 as u32,
            carry_modulus.0 as u32,
            pbs_type as u32,
            false,
            noise_reduction_type as u32,
        )
    };

    unsafe {
        cleanup_cuda_unchecked_match_value_or_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr))
    };

    size_tracker
}
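// Note the dry-run pattern above: the scratch call is made with
// allocate_gpu_memory = false, so it only walks the allocation plan and
// accumulates the would-be byte counts into the returned tracker. No device
// memory is allocated, and the immediate cleanup only frees host metadata.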
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
///   undefined behavior.
pub(crate) unsafe fn cuda_backend_cast_to_unsigned<T: UnsignedInteger, B: Numeric>(
    streams: &CudaStreams,
    output: &mut CudaRadixCiphertext,
    input: &mut CudaRadixCiphertext,
    input_is_signed: bool,
    requires_full_propagate: bool,
    target_num_blocks: u32,
    bootstrapping_key: &CudaVec<B>,
    keyswitch_key: &CudaVec<T>,
    glwe_dimension: GlweDimension,
    polynomial_size: PolynomialSize,
    big_lwe_dimension: LweDimension,
    small_lwe_dimension: LweDimension,
    ks_level: DecompositionLevelCount,
    ks_base_log: DecompositionBaseLog,
    pbs_level: DecompositionLevelCount,
    pbs_base_log: DecompositionBaseLog,
    pbs_type: PBSType,
    grouping_factor: LweBskGroupingFactor,
    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) {
    let message_modulus = input.info.blocks.first().unwrap().message_modulus;
    let carry_modulus = input.info.blocks.first().unwrap().carry_modulus;
    let num_input_blocks = input.d_blocks.lwe_ciphertext_count().0 as u32;

    let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);

    let mut input_degrees: Vec<u64> = input.info.blocks.iter().map(|b| b.degree.0).collect();
    let mut input_noise_levels: Vec<u64> =
        input.info.blocks.iter().map(|b| b.noise_level.0).collect();
    let mut cuda_ffi_input =
        prepare_cuda_radix_ffi(input, &mut input_degrees, &mut input_noise_levels);

    let mut output_degrees: Vec<u64> = output.info.blocks.iter().map(|b| b.degree.0).collect();
    let mut output_noise_levels: Vec<u64> =
        output.info.blocks.iter().map(|b| b.noise_level.0).collect();
    let mut cuda_ffi_output =
        prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);

    let mut mem_ptr: *mut i8 = std::ptr::null_mut();

    scratch_cuda_cast_to_unsigned_64(
        streams.ffi(),
        std::ptr::addr_of_mut!(mem_ptr),
        glwe_dimension.0 as u32,
        polynomial_size.0 as u32,
        big_lwe_dimension.0 as u32,
        small_lwe_dimension.0 as u32,
        ks_level.0 as u32,
        ks_base_log.0 as u32,
        pbs_level.0 as u32,
        pbs_base_log.0 as u32,
        grouping_factor.0 as u32,
        num_input_blocks,
        target_num_blocks,
        input_is_signed,
        requires_full_propagate,
        message_modulus.0 as u32,
        carry_modulus.0 as u32,
        pbs_type as u32,
        true,
        noise_reduction_type as u32,
    );

    cuda_cast_to_unsigned_64(
        streams.ffi(),
        &raw mut cuda_ffi_output,
        &raw mut cuda_ffi_input,
        mem_ptr,
        target_num_blocks,
        input_is_signed,
        bootstrapping_key.ptr.as_ptr(),
        keyswitch_key.ptr.as_ptr(),
    );

    cleanup_cuda_cast_to_unsigned_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));

    update_noise_degree(output, &cuda_ffi_output);
    if requires_full_propagate {
        update_noise_degree(input, &cuda_ffi_input);
    }
}

#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
///   undefined behavior.
pub(crate) unsafe fn cuda_backend_unchecked_match_value_or<
    T: UnsignedInteger,
    B: Numeric,
    R: CudaIntegerRadixCiphertext,
>(
    streams: &CudaStreams,
    lwe_array_out: &mut R,
    lwe_array_in_ct: &CudaRadixCiphertext,
    h_match_inputs: &[u64],
    h_match_outputs: &[u64],
    h_or_value: &[u64],
    num_matches: u32,
    num_input_blocks: u32,
    num_match_packed_blocks: u32,
    num_final_blocks: u32,
    max_output_is_zero: bool,
    message_modulus: MessageModulus,
    carry_modulus: CarryModulus,
    bootstrapping_key: &CudaVec<B>,
    keyswitch_key: &CudaVec<T>,
    glwe_dimension: GlweDimension,
    polynomial_size: PolynomialSize,
    big_lwe_dimension: LweDimension,
    small_lwe_dimension: LweDimension,
    ks_level: DecompositionLevelCount,
    ks_base_log: DecompositionBaseLog,
    pbs_level: DecompositionLevelCount,
    pbs_base_log: DecompositionBaseLog,
    pbs_type: PBSType,
    grouping_factor: LweBskGroupingFactor,
    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) {
    let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);

    let mut ffi_out_degrees: Vec<u64> = lwe_array_out
        .as_ref()
        .info
        .blocks
        .iter()
        .map(|b| b.degree.get())
        .collect();
    let mut ffi_out_noise_levels: Vec<u64> = lwe_array_out
        .as_ref()
        .info
        .blocks
        .iter()
        .map(|b| b.noise_level.0)
        .collect();
    let mut ffi_out_struct = prepare_cuda_radix_ffi(
        lwe_array_out.as_ref(),
        &mut ffi_out_degrees,
        &mut ffi_out_noise_levels,
    );

    let mut ffi_in_ct_degrees: Vec<u64> = lwe_array_in_ct
        .info
        .blocks
        .iter()
        .map(|b| b.degree.get())
        .collect();
    let mut ffi_in_ct_noise_levels: Vec<u64> = lwe_array_in_ct
        .info
        .blocks
        .iter()
        .map(|b| b.noise_level.0)
        .collect();
    let ffi_in_ct_struct = prepare_cuda_radix_ffi(
        lwe_array_in_ct,
        &mut ffi_in_ct_degrees,
        &mut ffi_in_ct_noise_levels,
    );

    let mut mem_ptr: *mut i8 = std::ptr::null_mut();

    scratch_cuda_unchecked_match_value_or_64(
        streams.ffi(),
        std::ptr::addr_of_mut!(mem_ptr),
        glwe_dimension.0 as u32,
        polynomial_size.0 as u32,
        big_lwe_dimension.0 as u32,
        small_lwe_dimension.0 as u32,
        ks_level.0 as u32,
        ks_base_log.0 as u32,
        pbs_level.0 as u32,
        pbs_base_log.0 as u32,
        grouping_factor.0 as u32,
        num_matches,
        num_input_blocks,
        num_match_packed_blocks,
        num_final_blocks,
        max_output_is_zero as u32,
        message_modulus.0 as u32,
        carry_modulus.0 as u32,
        pbs_type as u32,
        true,
        noise_reduction_type as u32,
    );

    cuda_unchecked_match_value_or_64(
        streams.ffi(),
        &raw mut ffi_out_struct,
        &raw const ffi_in_ct_struct,
        h_match_inputs.as_ptr(),
        h_match_outputs.as_ptr(),
        h_or_value.as_ptr(),
        mem_ptr,
        bootstrapping_key.ptr.as_ptr(),
        keyswitch_key.ptr.as_ptr(),
    );

    cleanup_cuda_unchecked_match_value_or_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));

    update_noise_degree(lwe_array_out.as_mut(), &ffi_out_struct);
}
@@ -17,11 +17,11 @@ use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
 use crate::integer::gpu::server_key::CudaBootstrappingKey;
 use crate::integer::gpu::{
     cuda_backend_apply_bivariate_lut, cuda_backend_apply_many_univariate_lut,
-    cuda_backend_apply_univariate_lut, cuda_backend_compute_prefix_sum_hillis_steele,
-    cuda_backend_extend_radix_with_sign_msb,
+    cuda_backend_apply_univariate_lut, cuda_backend_cast_to_unsigned,
+    cuda_backend_compute_prefix_sum_hillis_steele, cuda_backend_extend_radix_with_sign_msb,
     cuda_backend_extend_radix_with_trivial_zero_blocks_msb, cuda_backend_full_propagate_assign,
     cuda_backend_noise_squashing, cuda_backend_propagate_single_carry_assign,
-    cuda_backend_trim_radix_blocks_lsb, CudaServerKey, PBSType,
+    cuda_backend_trim_radix_blocks_lsb, cuda_backend_trim_radix_blocks_msb, CudaServerKey, PBSType,
 };
 use crate::integer::server_key::radix_parallel::OutputFlag;
 use crate::shortint::ciphertext::{Degree, NoiseLevel};
@@ -584,31 +584,19 @@
         num_blocks: usize,
         streams: &CudaStreams,
     ) -> T {
+        let total_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0;
+
         if num_blocks == 0 {
             return ct.duplicate(streams);
         }
-        let new_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 - num_blocks;
-        let ciphertext_modulus = ct.as_ref().d_blocks.ciphertext_modulus();
-        let lwe_size = ct.as_ref().d_blocks.lwe_dimension().to_lwe_size();
-        let shift = new_num_blocks * lwe_size.0;
+        let new_num_blocks = total_blocks - num_blocks;
+        let mut trimmed_ct: T = self.create_trivial_zero_radix(new_num_blocks, streams);

-        let mut trimmed_ct_vec = CudaVec::new(new_num_blocks * lwe_size.0, streams, 0);
         unsafe {
-            trimmed_ct_vec.copy_src_range_gpu_to_gpu_async(
-                0..shift,
-                &ct.as_ref().d_blocks.0.d_vec,
-                streams,
-                0,
-            );
-            streams.synchronize();
+            cuda_backend_trim_radix_blocks_msb(trimmed_ct.as_mut(), ct.as_ref(), streams);
         }
-        let trimmed_ct_list = CudaLweCiphertextList::from_cuda_vec(
-            trimmed_ct_vec,
-            LweCiphertextCount(new_num_blocks),
-            ciphertext_modulus,
-        );
-        let trimmed_ct_info = ct.as_ref().info.after_trim_radix_blocks_msb(new_num_blocks);
-        T::from(CudaRadixCiphertext::new(trimmed_ct_list, trimmed_ct_info))
+
+        trimmed_ct
     }

     pub(crate) fn generate_lookup_table<F>(&self, f: F) -> LookupTableOwned
@@ -1339,48 +1327,71 @@
     where
         T: CudaIntegerRadixCiphertext,
     {
-        if !source.block_carries_are_empty() {
-            self.full_propagate_assign(&mut source, streams);
-        }
-        let current_num_blocks = source.as_ref().info.blocks.len();
-        if T::IS_SIGNED {
-            // Casting from signed to unsigned
-            // We have to trim or sign extend first
-            if target_num_blocks > current_num_blocks {
-                let num_blocks_to_add = target_num_blocks - current_num_blocks;
-                let signed_res: T =
-                    self.extend_radix_with_sign_msb(&source, num_blocks_to_add, streams);
-                <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
-                    signed_res.into_inner(),
-                )
-            } else {
-                let num_blocks_to_remove = current_num_blocks - target_num_blocks;
-                let signed_res = self.trim_radix_blocks_msb(&source, num_blocks_to_remove, streams);
-                <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
-                    signed_res.into_inner(),
-                )
-            }
-        } else {
-            // Casting from unsigned to unsigned, this is just about trimming/extending with zeros
-            if target_num_blocks > current_num_blocks {
-                let num_blocks_to_add = target_num_blocks - current_num_blocks;
-                let unsigned_res = self.extend_radix_with_trivial_zero_blocks_msb(
-                    &source,
-                    num_blocks_to_add,
-                    streams,
-                );
-                <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
-                    unsigned_res.into_inner(),
-                )
-            } else {
-                let num_blocks_to_remove = current_num_blocks - target_num_blocks;
-                let unsigned_res =
-                    self.trim_radix_blocks_msb(&source, num_blocks_to_remove, streams);
-                <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
-                    unsigned_res.into_inner(),
-                )
-            }
-        }
+        let mut result: CudaUnsignedRadixCiphertext =
+            self.create_trivial_zero_radix(target_num_blocks, streams);
+
+        let requires_full_propagate = !source.block_carries_are_empty();
+
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_cast_to_unsigned(
+                        streams,
+                        result.as_mut(),
+                        source.as_mut(),
+                        T::IS_SIGNED,
+                        requires_full_propagate,
+                        target_num_blocks as u32,
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_bsk.glwe_dimension,
+                        d_bsk.polynomial_size,
+                        self.key_switching_key
+                            .input_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key
+                            .output_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count,
+                        d_bsk.decomp_base_log,
+                        PBSType::Classical,
+                        LweBskGroupingFactor(0),
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_cast_to_unsigned(
+                        streams,
+                        result.as_mut(),
+                        source.as_mut(),
+                        T::IS_SIGNED,
+                        requires_full_propagate,
+                        target_num_blocks as u32,
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_multibit_bsk.glwe_dimension,
+                        d_multibit_bsk.polynomial_size,
+                        self.key_switching_key
+                            .input_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key
+                            .output_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count,
+                        d_multibit_bsk.decomp_base_log,
+                        PBSType::MultiBit,
+                        d_multibit_bsk.grouping_factor,
+                        None,
+                    );
+                }
+            }
+        }
+
+        result
     }

     /// Cast a `CudaUnsignedRadixCiphertext` or `CudaSignedRadixCiphertext` to a
@@ -494,6 +494,46 @@ where
    }
}

/// For match_value_or operation
impl<'a, F>
    OpSequenceFunctionExecutor<(&'a RadixCiphertext, &'a MatchValues<u64>, u64), RadixCiphertext>
    for OpSequenceGpuMultiDeviceFunctionExecutor<F>
where
    F: Fn(
        &CudaServerKey,
        &CudaUnsignedRadixCiphertext,
        &MatchValues<u64>,
        u64,
        &CudaStreams,
    ) -> CudaUnsignedRadixCiphertext,
{
    fn setup(
        &mut self,
        cks: &RadixClientKey,
        sks: &CompressedServerKey,
        seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
    ) {
        self.setup_from_gpu_keys(cks, sks, seeder);
    }

    fn execute(
        &mut self,
        input: (&'a RadixCiphertext, &'a MatchValues<u64>, u64),
    ) -> RadixCiphertext {
        let context = self
            .context
            .as_ref()
            .expect("setup was not properly called");

        let d_ctxt_1: CudaUnsignedRadixCiphertext =
            CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);

        let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, input.2, &context.streams);

        d_res.to_radix_ciphertext(&context.streams)
    }
}

impl<'a, F>
    OpSequenceFunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext),
@@ -5,9 +5,10 @@
 use crate::integer::keycache::KEY_CACHE;
 use crate::integer::server_key::radix_parallel::tests_long_run::test_random_op_sequence::{
     random_op_sequence_test, BinaryOpExecutor, ComparisonOpExecutor, DivRemOpExecutor,
-    Log2OpExecutor, MatchValueExecutor, OprfBoundedExecutor, OprfCustomRangeExecutor, OprfExecutor,
-    OverflowingOpExecutor, ScalarBinaryOpExecutor, ScalarComparisonOpExecutor,
-    ScalarDivRemOpExecutor, ScalarOverflowingOpExecutor, SelectOpExecutor, UnaryOpExecutor,
+    Log2OpExecutor, MatchValueExecutor, MatchValueOrExecutor, OprfBoundedExecutor,
+    OprfCustomRangeExecutor, OprfExecutor, OverflowingOpExecutor, ScalarBinaryOpExecutor,
+    ScalarComparisonOpExecutor, ScalarDivRemOpExecutor, ScalarOverflowingOpExecutor,
+    SelectOpExecutor, UnaryOpExecutor,
 };
 use crate::integer::server_key::radix_parallel::tests_long_run::{
     get_user_defined_seed, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
@@ -54,6 +55,7 @@ pub(crate) fn random_op_sequence_test_init_gpu<P>(
    )],
    log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
    match_value_ops: &mut [(MatchValueExecutor, String)],
    match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
    oprf_ops: &mut [(OprfExecutor, String)],
    oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
    oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -88,6 +90,7 @@ where
        + scalar_div_rem_op.len()
        + log2_ops.len()
        + match_value_ops.len()
        + match_value_or_ops.len()
        + oprf_ops.len()
        + oprf_bounded_ops.len()
        + oprf_custom_range_ops.len();
@@ -142,6 +145,9 @@ where
    for x in match_value_ops.iter_mut() {
        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
    }
    for x in match_value_or_ops.iter_mut() {
        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
    }
    for x in oprf_ops.iter_mut() {
        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
    }
@@ -317,6 +323,7 @@ where
    )];

    let mut match_value_ops: Vec<(MatchValueExecutor, String)> = vec![];
    let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![];

    let mut oprf_ops: Vec<(OprfExecutor, String)> = vec![];
    let mut oprf_bounded_ops: Vec<(OprfBoundedExecutor, String)> = vec![];
@@ -336,6 +343,7 @@ where
        &mut scalar_div_rem_op,
        &mut log2_ops,
        &mut match_value_ops,
        &mut match_value_or_ops,
        &mut oprf_ops,
        &mut oprf_bounded_ops,
        &mut oprf_custom_range_ops,
@@ -357,6 +365,7 @@ where
        &mut scalar_div_rem_op,
        &mut log2_ops,
        &mut match_value_ops,
        &mut match_value_or_ops,
        &mut oprf_ops,
        &mut oprf_bounded_ops,
        &mut oprf_custom_range_ops,
@@ -781,6 +790,14 @@ where
    let mut match_value_ops: Vec<(MatchValueExecutor, String)> =
        vec![(Box::new(match_value_executor), "match_value".to_string())];

    // Match Value Or Executor
    let match_value_or_executor =
        OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::match_value_or);
    let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![(
        Box::new(match_value_or_executor),
        "match_value_or".to_string(),
    )];

    // OPRF Executors
    let oprf_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(
        &CudaServerKey::par_generate_oblivious_pseudo_random_unsigned_integer,
@@ -816,6 +833,7 @@ where
        &mut scalar_div_rem_op,
        &mut log2_ops,
        &mut match_value_ops,
        &mut match_value_or_ops,
        &mut oprf_ops,
        &mut oprf_bounded_ops,
        &mut oprf_custom_range_ops,
@@ -837,6 +855,7 @@ where
        &mut scalar_div_rem_op,
        &mut log2_ops,
        &mut match_value_ops,
        &mut match_value_or_ops,
        &mut oprf_ops,
        &mut oprf_bounded_ops,
        &mut oprf_custom_range_ops,
@@ -9,8 +9,9 @@ use crate::integer::gpu::server_key::radix::CudaRadixCiphertext;
 use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
 use crate::integer::gpu::{
     cuda_backend_aggregate_one_hot_vector, cuda_backend_compute_equality_selectors,
-    cuda_backend_create_possible_results, cuda_backend_get_unchecked_match_value_size_on_gpu,
-    cuda_backend_unchecked_match_value, PBSType,
+    cuda_backend_create_possible_results, cuda_backend_get_unchecked_match_value_or_size_on_gpu,
+    cuda_backend_get_unchecked_match_value_size_on_gpu, cuda_backend_unchecked_match_value,
+    cuda_backend_unchecked_match_value_or, PBSType,
 };
 pub use crate::integer::server_key::radix_parallel::MatchValues;
 use crate::prelude::CastInto;
@@ -410,26 +411,229 @@ impl CudaServerKey {
        Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<usize>,
    {
        if matches.get_values().is_empty() {
            let ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(
                or_value,
                self.num_blocks_to_represent_unsigned_value(or_value),
                streams,
            );
            let num_blocks = self.num_blocks_to_represent_unsigned_value(or_value);
            let ct: CudaUnsignedRadixCiphertext =
                self.create_trivial_radix(or_value, num_blocks, streams);
            return ct;
        }
        let (result, selected) = self.unchecked_match_value(ct, matches, streams);

        // The result must have as many block to represent either the result of the match or the
        // or_value
        let num_blocks_to_represent_or_value =
            self.num_blocks_to_represent_unsigned_value(or_value);
        let num_blocks = (result.as_ref().d_blocks.lwe_ciphertext_count().0)
            .max(num_blocks_to_represent_or_value);
        let or_value: CudaUnsignedRadixCiphertext =
            self.create_trivial_radix(or_value, num_blocks, streams);
        let casted_result = self.cast_to_unsigned(result, num_blocks, streams);
        // Note, this could be slightly faster when we have scalar if then_else
        self.unchecked_if_then_else(&selected, &casted_result, &or_value, streams)
        let num_input_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
        let num_bits_in_message = self.message_modulus.0.ilog2();

        let h_match_inputs: Vec<u64> = matches
            .get_values()
            .par_iter()
            .map(|(input, _output)| *input)
            .flat_map(|input_value| {
                BlockDecomposer::new(input_value, num_bits_in_message)
                    .take(num_input_blocks as usize)
                    .map(|block_value| block_value.cast_into())
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        let max_output_value_match = matches
            .get_values()
            .iter()
            .copied()
            .max_by(|(_, outputl), (_, outputr)| outputl.cmp(outputr))
            .expect("luts is not empty at this point")
            .1;

        let num_blocks_match = self.num_blocks_to_represent_unsigned_value(max_output_value_match);
        let num_blocks_or = self.num_blocks_to_represent_unsigned_value(or_value);
        let final_num_blocks = num_blocks_match.max(num_blocks_or);

        let num_match_packed_blocks = num_blocks_match.div_ceil(2) as u32;

        let h_match_outputs: Vec<u64> = matches
            .get_values()
            .par_iter()
            .map(|(_input, output)| *output)
            .flat_map(|output_value| {
                BlockDecomposer::new(output_value, 2 * num_bits_in_message)
                    .take(num_match_packed_blocks as usize)
                    .map(|block_value| block_value.cast_into())
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        let h_or_value: Vec<u64> = BlockDecomposer::new(or_value, num_bits_in_message)
            .take(final_num_blocks)
            .map(|block_value| block_value.cast_into())
            .collect();

        let mut result: CudaUnsignedRadixCiphertext =
            self.create_trivial_zero_radix(final_num_blocks, streams);

        let max_output_is_zero = max_output_value_match == Clear::ZERO;
        let num_matches = matches.get_values().len() as u32;

        unsafe {
            match &self.bootstrapping_key {
                CudaBootstrappingKey::Classic(d_bsk) => {
                    cuda_backend_unchecked_match_value_or(
                        streams,
                        &mut result,
                        ct.as_ref(),
                        &h_match_inputs,
                        &h_match_outputs,
                        &h_or_value,
                        num_matches,
                        num_input_blocks,
                        num_match_packed_blocks,
                        final_num_blocks as u32,
                        max_output_is_zero,
                        self.message_modulus,
                        self.carry_modulus,
                        &d_bsk.d_vec,
                        &self.key_switching_key.d_vec,
                        d_bsk.glwe_dimension,
                        d_bsk.polynomial_size,
                        self.key_switching_key
                            .input_key_lwe_size()
                            .to_lwe_dimension(),
                        self.key_switching_key
                            .output_key_lwe_size()
                            .to_lwe_dimension(),
                        self.key_switching_key.decomposition_level_count(),
                        self.key_switching_key.decomposition_base_log(),
                        d_bsk.decomp_level_count,
                        d_bsk.decomp_base_log,
                        PBSType::Classical,
                        LweBskGroupingFactor(0),
                        d_bsk.ms_noise_reduction_configuration.as_ref(),
                    );
                }
                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
                    cuda_backend_unchecked_match_value_or(
                        streams,
                        &mut result,
                        ct.as_ref(),
                        &h_match_inputs,
                        &h_match_outputs,
                        &h_or_value,
                        num_matches,
                        num_input_blocks,
                        num_match_packed_blocks,
                        final_num_blocks as u32,
                        max_output_is_zero,
                        self.message_modulus,
                        self.carry_modulus,
                        &d_multibit_bsk.d_vec,
                        &self.key_switching_key.d_vec,
                        d_multibit_bsk.glwe_dimension,
                        d_multibit_bsk.polynomial_size,
                        self.key_switching_key
                            .input_key_lwe_size()
                            .to_lwe_dimension(),
                        self.key_switching_key
                            .output_key_lwe_size()
                            .to_lwe_dimension(),
                        self.key_switching_key.decomposition_level_count(),
                        self.key_switching_key.decomposition_base_log(),
                        d_multibit_bsk.decomp_level_count,
                        d_multibit_bsk.decomp_base_log,
                        PBSType::MultiBit,
                        d_multibit_bsk.grouping_factor,
                        None,
                    );
                }
            }
        }

        result
    }

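Illustration (not part of the diff): the host-side preparation above flattens each match input into message-sized radix blocks, while match outputs are decomposed with 2 * num_bits_in_message bits per block, i.e. two message blocks packed into one output block. A minimal pure-Rust sketch of that bookkeeping; `decompose_blocks` is a hypothetical stand-in for tfhe-rs's `BlockDecomposer`:

// Hypothetical stand-in for BlockDecomposer: little-endian fixed-width blocks.
fn decompose_blocks(mut value: u64, bits_per_block: u32, num_blocks: usize) -> Vec<u64> {
    let mask = (1u64 << bits_per_block) - 1;
    (0..num_blocks)
        .map(|_| {
            let block = value & mask;
            value >>= bits_per_block;
            block
        })
        .collect()
}

fn main() {
    let num_bits_in_message = 4u64.ilog2(); // message_modulus = 4 -> 2 bits per block
    // Inputs: one message block per radix block.
    assert_eq!(
        decompose_blocks(0b11_01_10, num_bits_in_message, 3),
        vec![0b10, 0b01, 0b11]
    );
    // Outputs: two message blocks packed per block, so half as many blocks.
    let num_blocks_match = 3usize;
    let num_match_packed_blocks = num_blocks_match.div_ceil(2);
    assert_eq!(
        decompose_blocks(0b11_01_10, 2 * num_bits_in_message, num_match_packed_blocks),
        vec![0b0110, 0b11]
    );
}

Packing halves the number of output blocks the LUTs must produce, which is why num_match_packed_blocks = num_blocks_match.div_ceil(2) above.
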
    pub fn get_unchecked_match_value_or_size_on_gpu<Clear>(
        &self,
        ct: &CudaUnsignedRadixCiphertext,
        matches: &MatchValues<Clear>,
        or_value: Clear,
        streams: &CudaStreams,
    ) -> u64
    where
        Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<usize>,
    {
        if matches.get_values().is_empty() {
            return 0;
        }

        let num_input_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;

        let max_output_value_match = matches
            .get_values()
            .iter()
            .copied()
            .max_by(|(_, outputl), (_, outputr)| outputl.cmp(outputr))
            .expect("luts is not empty at this point")
            .1;

        let num_blocks_match = self.num_blocks_to_represent_unsigned_value(max_output_value_match);
        let num_blocks_or = self.num_blocks_to_represent_unsigned_value(or_value);
        let final_num_blocks = num_blocks_match.max(num_blocks_or);

        let num_match_packed_blocks = num_blocks_match.div_ceil(2) as u32;

        let max_output_is_zero = max_output_value_match == Clear::ZERO;
        let num_matches = matches.get_values().len() as u32;

        match &self.bootstrapping_key {
            CudaBootstrappingKey::Classic(d_bsk) => {
                cuda_backend_get_unchecked_match_value_or_size_on_gpu(
                    streams,
                    d_bsk.glwe_dimension,
                    d_bsk.polynomial_size,
                    self.key_switching_key
                        .input_key_lwe_size()
                        .to_lwe_dimension(),
                    self.key_switching_key
                        .output_key_lwe_size()
                        .to_lwe_dimension(),
                    self.key_switching_key.decomposition_level_count(),
                    self.key_switching_key.decomposition_base_log(),
                    d_bsk.decomp_level_count,
                    d_bsk.decomp_base_log,
                    LweBskGroupingFactor(0),
                    self.message_modulus,
                    self.carry_modulus,
                    PBSType::Classical,
                    num_matches,
                    num_input_blocks,
                    num_match_packed_blocks,
                    final_num_blocks as u32,
                    max_output_is_zero,
                    d_bsk.ms_noise_reduction_configuration.as_ref(),
                )
            }
            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
                cuda_backend_get_unchecked_match_value_or_size_on_gpu(
                    streams,
                    d_multibit_bsk.glwe_dimension,
                    d_multibit_bsk.polynomial_size,
                    self.key_switching_key
                        .input_key_lwe_size()
                        .to_lwe_dimension(),
                    self.key_switching_key
                        .output_key_lwe_size()
                        .to_lwe_dimension(),
                    self.key_switching_key.decomposition_level_count(),
                    self.key_switching_key.decomposition_base_log(),
                    d_multibit_bsk.decomp_level_count,
                    d_multibit_bsk.decomp_base_log,
                    d_multibit_bsk.grouping_factor,
                    self.message_modulus,
                    self.carry_modulus,
                    PBSType::MultiBit,
                    num_matches,
                    num_input_blocks,
                    num_match_packed_blocks,
                    final_num_blocks as u32,
                    max_output_is_zero,
                    None,
                )
            }
        }
    }

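Illustration (not part of the diff): both entry points size the output identically: enough blocks for the largest match output or for the fallback, with match outputs then packed two-per-block. A sketch of that arithmetic, assuming `num_blocks_to_represent_unsigned_value` behaves as the usual ceil(bit-width / bits-per-block), which is an assumption, not a quote of its implementation:

// Assumed behavior: at least one block, otherwise ceil(bit_width / bits_per_block).
fn blocks_to_represent(value: u64, bits_per_block: u32) -> usize {
    if value == 0 {
        1
    } else {
        (value.ilog2() + 1).div_ceil(bits_per_block) as usize
    }
}

fn main() {
    let bits_per_block = 2; // message_modulus = 4
    let max_output_value_match = 300u64; // 9 bits -> 5 blocks
    let or_value = 7u64; // 3 bits -> 2 blocks
    let num_blocks_match = blocks_to_represent(max_output_value_match, bits_per_block);
    let num_blocks_or = blocks_to_represent(or_value, bits_per_block);
    let final_num_blocks = num_blocks_match.max(num_blocks_or);
    let num_match_packed_blocks = num_blocks_match.div_ceil(2);
    assert_eq!((final_num_blocks, num_match_packed_blocks), (5, 3));
}
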
    /// `match` an input value to an output value

@@ -267,9 +267,7 @@ impl<
            }
        }
    }
}

impl RandomOpSequenceDataGenerator<u64, RadixCiphertext> {
    #[allow(clippy::manual_is_multiple_of)]
    pub(crate) fn gen_match_values(&mut self, key_to_match: u64) -> (MatchValues<u64>, u64, bool) {
        let mut pairings = Vec::new();
@@ -297,6 +295,19 @@ impl RandomOpSequenceDataGenerator<u64, RadixCiphertext> {
            does_match,
        )
    }

    pub(crate) fn gen_match_values_or(
        &mut self,
        key_to_match: u64,
    ) -> (MatchValues<u64>, u64, u64) {
        let (mv, match_val, does_match) = self.gen_match_values(key_to_match);

        let or_value = self.deterministic_seeder.seed().0 as u64;

        let expected = if does_match { match_val } else { or_value };

        (mv, or_value, expected)
    }
}
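Illustration (not part of the diff): the third element of the returned triple is the clear oracle the harness later checks the decrypted result against. A worked example of that selection:

fn expected(does_match: bool, match_val: u64, or_value: u64) -> u64 {
    // Mirrors gen_match_values_or: matched output wins, otherwise the fallback.
    if does_match { match_val } else { or_value }
}

fn main() {
    assert_eq!(expected(true, 42, 7), 42); // key_to_match hit a pairing
    assert_eq!(expected(false, 42, 7), 7); // no pairing matched -> or_value
}
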

#[allow(clippy::too_many_arguments)]

@@ -79,6 +79,13 @@ pub(crate) type MatchValueExecutor = Box<
    >,
>;

pub(crate) type MatchValueOrExecutor = Box<
    dyn for<'a> OpSequenceFunctionExecutor<
        (&'a RadixCiphertext, &'a MatchValues<u64>, u64),
        RadixCiphertext,
    >,
>;
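Illustration (not part of the diff): the `for<'a>` in these aliases is what lets one boxed executor live in a collection and still be called with borrows of any lifetime. A self-contained sketch in which the hypothetical `Executor` trait only mirrors the shape of `OpSequenceFunctionExecutor`:

trait Executor<I, O> {
    fn execute(&mut self, input: I) -> O;
}

// Adapter turning any function into an executor.
struct FnExecutor<F>(F);

impl<I, O, F: FnMut(I) -> O> Executor<I, O> for FnExecutor<F> {
    fn execute(&mut self, input: I) -> O {
        (self.0)(input)
    }
}

fn sum_plus((xs, extra): (&[u64], u64)) -> u64 {
    xs.iter().sum::<u64>() + extra
}

// Higher-ranked alias, same shape as MatchValueOrExecutor.
type SliceSumExecutor = Box<dyn for<'a> Executor<(&'a [u64], u64), u64>>;

fn main() {
    let mut ex: SliceSumExecutor = Box::new(FnExecutor(sum_plus));
    assert_eq!(ex.execute((&[1, 2, 3], 10)), 16);
}
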

pub(crate) type OprfExecutor =
    Box<dyn for<'a> OpSequenceFunctionExecutor<(Seed, u64), RadixCiphertext>>;

@@ -474,6 +481,15 @@ where
    let mut match_value_ops: Vec<(MatchValueExecutor, String)> =
        vec![(Box::new(match_value_executor), "match_value".to_string())];

    // Match Values Or Executor
    let match_value_or_executor =
        OpSequenceCpuFunctionExecutor::new(&ServerKey::match_value_or_parallelized);

    let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![(
        Box::new(match_value_or_executor),
        "match_value_or".to_string(),
    )];

    // OPRF Executors
    let oprf_executor = OpSequenceCpuFunctionExecutor::new(
        &ServerKey::par_generate_oblivious_pseudo_random_unsigned_integer,
@@ -514,6 +530,7 @@ where
        &mut scalar_div_rem_op,
        &mut log2_ops,
        &mut match_value_ops,
        &mut match_value_or_ops,
        &mut oprf_ops,
        &mut oprf_bounded_ops,
        &mut oprf_custom_range_ops,
@@ -535,6 +552,7 @@ where
        &mut scalar_div_rem_op,
        &mut log2_ops,
        &mut match_value_ops,
        &mut match_value_or_ops,
        &mut oprf_ops,
        &mut oprf_bounded_ops,
        &mut oprf_custom_range_ops,
@@ -572,6 +590,7 @@ pub(crate) fn random_op_sequence_test_init_cpu<P>(
    )],
    log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
    match_value_ops: &mut [(MatchValueExecutor, String)],
    match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
    oprf_ops: &mut [(OprfExecutor, String)],
    oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
    oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -602,6 +621,7 @@ where
        + scalar_div_rem_op.len()
        + log2_ops.len()
        + match_value_ops.len()
        + match_value_or_ops.len()
        + oprf_ops.len()
        + oprf_bounded_ops.len()
        + oprf_custom_range_ops.len();
@@ -661,6 +681,9 @@ where
    for x in match_value_ops.iter_mut() {
        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
    }
    for x in match_value_or_ops.iter_mut() {
        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
    }
    for x in oprf_ops.iter_mut() {
        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
    }
@@ -706,6 +729,7 @@ pub(crate) fn random_op_sequence_test(
    )],
    log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
    match_value_ops: &mut [(MatchValueExecutor, String)],
    match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
    oprf_ops: &mut [(OprfExecutor, String)],
    oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
    oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -729,7 +753,10 @@ pub(crate) fn random_op_sequence_test(
        div_rem_op_range.end..div_rem_op_range.end + scalar_div_rem_op.len();
    let log2_ops_range = scalar_div_rem_op_range.end..scalar_div_rem_op_range.end + log2_ops.len();
    let match_value_ops_range = log2_ops_range.end..log2_ops_range.end + match_value_ops.len();
    let oprf_ops_range = match_value_ops_range.end..match_value_ops_range.end + oprf_ops.len();
    let match_value_or_ops_range =
        match_value_ops_range.end..match_value_ops_range.end + match_value_or_ops.len();
    let oprf_ops_range =
        match_value_or_ops_range.end..match_value_or_ops_range.end + oprf_ops.len();
    let oprf_bounded_ops_range = oprf_ops_range.end..oprf_ops_range.end + oprf_bounded_ops.len();
    let oprf_custom_range_ops_range =
        oprf_bounded_ops_range.end..oprf_bounded_ops_range.end + oprf_custom_range_ops.len();
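Illustration (not part of the diff): the test dispatches a flat op index through chained half-open ranges, so inserting the match_value_or family between match_value and oprf only shifts the later ranges without touching the dispatch logic. In miniature:

fn main() {
    // Each family claims the range starting where the previous one ends.
    let (match_value_len, match_value_or_len, oprf_len) = (1usize, 1, 1);
    let match_value_range = 0..match_value_len;
    let match_value_or_range = match_value_range.end..match_value_range.end + match_value_or_len;
    let oprf_range = match_value_or_range.end..match_value_or_range.end + oprf_len;
    for i in 0..oprf_range.end {
        if match_value_range.contains(&i) {
            println!("{i}: match_value op {}", i - match_value_range.start);
        } else if match_value_or_range.contains(&i) {
            println!("{i}: match_value_or op {}", i - match_value_or_range.start);
        } else if oprf_range.contains(&i) {
            println!("{i}: oprf op {}", i - oprf_range.start);
        }
    }
}
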
@@ -1149,6 +1176,35 @@ pub(crate) fn random_op_sequence_test(
                operand.p,
                operand.p,
            );
        } else if match_value_or_ops_range.contains(&i) {
            let index = i - match_value_or_ops_range.start;
            let (match_value_or_executor, fn_name) = &mut match_value_or_ops[index];
            let operand = datagen.gen_op_single_operand(idx, fn_name);

            let (match_values, or_value, expected_value) = datagen.gen_match_values_or(operand.p);

            println!("{idx}: MatchValuesOr generated. Expected: {expected_value}");

            let res = match_value_or_executor.execute((&operand.c, &match_values, or_value));
            // Determinism check
            let res_1 = match_value_or_executor.execute((&operand.c, &match_values, or_value));

            let decrypted_res: u64 = cks.decrypt(&res);

            let res_casted = sks.cast_to_unsigned(res.clone(), NB_CTXT_LONG_RUN);
            datagen.put_op_result_random_side(expected_value, &res_casted, fn_name, idx);

            sanity_check_op_sequence_result_u64(
                idx,
                fn_name,
                fn_index,
                &res,
                &res_1,
                decrypted_res,
                expected_value,
                operand.p,
                operand.p,
            );
        } else if oprf_ops_range.contains(&i) {
            let index = i - oprf_ops_range.start;
            let (op_executor, fn_name) = &mut oprf_ops[index];

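Illustration (not part of the diff): the "Determinism check" above runs the same op twice on identical inputs and requires identical results before correctness is checked. The pattern in miniature:

fn run_twice_and_check<F: FnMut() -> Vec<u64>>(mut op: F, expected: &[u64]) {
    let res = op();
    let res_1 = op(); // determinism check: same inputs, same output
    assert_eq!(res, res_1, "non-deterministic result");
    assert_eq!(res, expected, "wrong result");
}

fn main() {
    run_twice_and_check(|| vec![1, 2, 3], &[1, 2, 3]);
}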