refactor(gpu): move unchecked_match_value_or to the backend

Enzo Di Maria
2025-11-20 16:00:16 +01:00
committed by Agnès Leroy
parent 184f40439e
commit 32b1a7ab1d
20 changed files with 1396 additions and 99 deletions

View File

@@ -75,3 +75,55 @@ template <typename Torus> struct int_extend_radix_with_sign_msb_buffer {
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
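// Scratch state for casting a radix ciphertext to unsigned: optionally a full
// carry-propagation buffer, plus a sign-extension buffer when a signed input
// must grow to more blocks.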
template <typename Torus> struct int_cast_to_unsigned_buffer {
int_radix_params params;
bool allocate_gpu_memory;
bool requires_full_propagate;
bool requires_sign_extension;
int_fullprop_buffer<Torus> *prop_buffer;
int_extend_radix_with_sign_msb_buffer<Torus> *extend_buffer;
int_cast_to_unsigned_buffer(CudaStreams streams, int_radix_params params,
uint32_t num_input_blocks,
uint32_t target_num_blocks, bool input_is_signed,
bool requires_full_propagate,
bool allocate_gpu_memory,
uint64_t &size_tracker) {
this->params = params;
this->allocate_gpu_memory = allocate_gpu_memory;
this->requires_full_propagate = requires_full_propagate;
this->prop_buffer = nullptr;
this->extend_buffer = nullptr;
if (requires_full_propagate) {
this->prop_buffer = new int_fullprop_buffer<Torus>(
streams, params, allocate_gpu_memory, size_tracker);
}
this->requires_sign_extension =
(target_num_blocks > num_input_blocks) && input_is_signed;
if (this->requires_sign_extension) {
uint32_t num_blocks_to_add = target_num_blocks - num_input_blocks;
this->extend_buffer = new int_extend_radix_with_sign_msb_buffer<Torus>(
streams, params, num_input_blocks, num_blocks_to_add,
allocate_gpu_memory, size_tracker);
}
}
void release(CudaStreams streams) {
if (this->prop_buffer) {
this->prop_buffer->release(streams);
delete this->prop_buffer;
}
if (this->extend_buffer) {
this->extend_buffer->release(streams);
delete this->extend_buffer;
}
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};

View File

@@ -569,6 +569,10 @@ void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI const *input,
CudaStreamsFFI streams);
void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI const *input,
CudaStreamsFFI streams);
uint64_t scratch_cuda_apply_noise_squashing(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size,
@@ -850,6 +854,46 @@ void cuda_unchecked_match_value_64(
void cleanup_cuda_unchecked_match_value_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void);
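// Entry points for cast_to_unsigned (scratch / run / cleanup).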
uint64_t scratch_cuda_cast_to_unsigned_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
bool requires_full_propagate, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
uint32_t target_num_blocks, bool input_is_signed,
void *const *bsks, void *const *ksks);
void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void);
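// Entry points for unchecked_match_value_or (scratch / run / cleanup).
// max_output_is_zero is non-zero when every match output is zero, which lets
// the scratch buffer skip its match-result temporary.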
uint64_t scratch_cuda_unchecked_match_value_or_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_matches, uint32_t num_input_blocks,
uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
uint32_t max_output_is_zero, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_unchecked_match_value_or_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in_ct,
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
const uint64_t *h_or_value, int8_t *mem, void *const *bsks,
void *const *ksks);
void cleanup_cuda_unchecked_match_value_or_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void);
} // extern "C"
#endif // CUDA_INTEGER_H

View File

@@ -1,4 +1,5 @@
#pragma once
#include "cast.h"
#include "integer/comparison.h"
#include "integer/radix_ciphertext.cuh"
#include "integer_utilities.h"
@@ -593,3 +594,98 @@ template <typename Torus> struct int_unchecked_match_buffer {
delete this->packed_selectors_ct;
}
};
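// Scratch state for match_value_or: an int_unchecked_match_buffer for the
// match itself, a cmux buffer to select between the match result and the
// trivially encrypted or_value, and temporaries for intermediate results.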
template <typename Torus> struct int_unchecked_match_value_or_buffer {
int_radix_params params;
bool allocate_gpu_memory;
uint32_t num_matches;
uint32_t num_input_blocks;
uint32_t num_match_packed_blocks;
uint32_t num_final_blocks;
bool max_output_is_zero;
int_unchecked_match_buffer<Torus> *match_buffer;
int_cmux_buffer<Torus> *cmux_buffer;
CudaRadixCiphertextFFI *tmp_match_result;
CudaRadixCiphertextFFI *tmp_match_bool;
CudaRadixCiphertextFFI *tmp_or_value;
Torus *d_or_value;
int_unchecked_match_value_or_buffer(
CudaStreams streams, int_radix_params params, uint32_t num_matches,
uint32_t num_input_blocks, uint32_t num_match_packed_blocks,
uint32_t num_final_blocks, bool max_output_is_zero,
bool allocate_gpu_memory, uint64_t &size_tracker) {
this->params = params;
this->allocate_gpu_memory = allocate_gpu_memory;
this->num_matches = num_matches;
this->num_input_blocks = num_input_blocks;
this->num_match_packed_blocks = num_match_packed_blocks;
this->num_final_blocks = num_final_blocks;
this->max_output_is_zero = max_output_is_zero;
this->match_buffer = new int_unchecked_match_buffer<Torus>(
streams, params, num_matches, num_input_blocks, num_match_packed_blocks,
max_output_is_zero, allocate_gpu_memory, size_tracker);
this->cmux_buffer = new int_cmux_buffer<Torus>(
streams, [](Torus x) -> Torus { return x == 1; }, params,
num_final_blocks, allocate_gpu_memory, size_tracker);
this->tmp_match_result = new CudaRadixCiphertextFFI;
this->tmp_match_bool = new CudaRadixCiphertextFFI;
this->tmp_or_value = new CudaRadixCiphertextFFI;
this->d_or_value = (Torus *)cuda_malloc_with_size_tracking_async(
num_final_blocks * sizeof(Torus), streams.stream(0),
streams.gpu_index(0), size_tracker, allocate_gpu_memory);
if (!max_output_is_zero) {
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->tmp_match_result,
num_final_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
}
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->tmp_match_bool, 1,
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
create_zero_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), this->tmp_or_value,
num_final_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
}
void release(CudaStreams streams) {
this->match_buffer->release(streams);
delete this->match_buffer;
this->cmux_buffer->release(streams);
delete this->cmux_buffer;
if (!max_output_is_zero) {
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->tmp_match_result,
this->allocate_gpu_memory);
}
delete this->tmp_match_result;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->tmp_match_bool,
this->allocate_gpu_memory);
delete this->tmp_match_bool;
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
this->tmp_or_value,
this->allocate_gpu_memory);
delete this->tmp_or_value;
cuda_drop_async(this->d_or_value, streams.stream(0), streams.gpu_index(0));
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};

View File

@@ -18,6 +18,15 @@ void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
}
void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI const *input,
CudaStreamsFFI streams) {
auto cuda_streams = CudaStreams(streams);
host_trim_radix_blocks_msb<uint64_t>(output, input, cuda_streams);
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
}
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
@@ -64,3 +73,46 @@ void cleanup_cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
delete mem_ptr;
*mem_ptr_void = nullptr;
}
uint64_t scratch_cuda_cast_to_unsigned_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
bool requires_full_propagate, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
big_lwe_dimension, small_lwe_dimension, ks_level,
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
message_modulus, carry_modulus, noise_reduction_type);
return scratch_cuda_cast_to_unsigned<uint64_t>(
CudaStreams(streams), (int_cast_to_unsigned_buffer<uint64_t> **)mem_ptr,
params, num_input_blocks, target_num_blocks, input_is_signed,
requires_full_propagate, allocate_gpu_memory);
}
void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
uint32_t target_num_blocks, bool input_is_signed,
void *const *bsks, void *const *ksks) {
host_cast_to_unsigned<uint64_t>(
CudaStreams(streams), output, input,
(int_cast_to_unsigned_buffer<uint64_t> *)mem_ptr, target_num_blocks,
input_is_signed, bsks, (uint64_t **)ksks);
}
void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
int_cast_to_unsigned_buffer<uint64_t> *mem_ptr =
(int_cast_to_unsigned_buffer<uint64_t> *)(*mem_ptr_void);
mem_ptr->release(CudaStreams(streams));
delete mem_ptr;
*mem_ptr_void = nullptr;
}

View File

@@ -36,6 +36,23 @@ __host__ void host_trim_radix_blocks_lsb(CudaRadixCiphertextFFI *output,
input->num_radix_blocks);
}
template <typename Torus>
__host__ void
host_trim_radix_blocks_msb(CudaRadixCiphertextFFI *output_radix,
const CudaRadixCiphertextFFI *input_radix,
CudaStreams streams) {
PANIC_IF_FALSE(input_radix->num_radix_blocks >=
output_radix->num_radix_blocks,
"Cuda error: input radix ciphertext has fewer blocks than "
"required to keep");
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), output_radix, 0,
output_radix->num_radix_blocks, input_radix, 0,
output_radix->num_radix_blocks);
}
template <typename Torus>
__host__ uint64_t scratch_extend_radix_with_sign_msb(
CudaStreams streams, int_extend_radix_with_sign_msb_buffer<Torus> **mem_ptr,
@@ -91,4 +108,56 @@ __host__ void host_extend_radix_with_sign_msb(
POP_RANGE()
}
template <typename Torus>
uint64_t scratch_cuda_cast_to_unsigned(
CudaStreams streams, int_cast_to_unsigned_buffer<Torus> **mem_ptr,
int_radix_params params, uint32_t num_input_blocks,
uint32_t target_num_blocks, bool input_is_signed,
bool requires_full_propagate, bool allocate_gpu_memory) {
uint64_t size_tracker = 0;
*mem_ptr = new int_cast_to_unsigned_buffer<Torus>(
streams, params, num_input_blocks, target_num_blocks, input_is_signed,
requires_full_propagate, allocate_gpu_memory, size_tracker);
return size_tracker;
}
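// Dispatch: optionally propagate carries in place, then sign-extend (signed
// input), zero-extend (unsigned input), trim MSBs, or copy unchanged,
// depending on how target_num_blocks compares to the input's block count.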
template <typename Torus>
__host__ void
host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
CudaRadixCiphertextFFI *input,
int_cast_to_unsigned_buffer<Torus> *mem_ptr,
uint32_t target_num_blocks, bool input_is_signed,
void *const *bsks, Torus *const *ksks) {
uint32_t current_num_blocks = input->num_radix_blocks;
if (mem_ptr->requires_full_propagate) {
host_full_propagate_inplace<Torus>(streams, input, mem_ptr->prop_buffer,
ksks, bsks, current_num_blocks);
}
if (target_num_blocks > current_num_blocks) {
uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;
if (input_is_signed) {
host_extend_radix_with_sign_msb<Torus>(
streams, output, input, mem_ptr->extend_buffer, num_blocks_to_add,
bsks, (Torus **)ksks);
} else {
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
streams);
}
} else if (target_num_blocks < current_num_blocks) {
host_trim_radix_blocks_msb<Torus>(output, input, streams);
} else {
copy_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), output, 0, current_num_blocks,
input, 0, current_num_blocks);
}
}
#endif
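In clear-value terms, the dispatch above behaves like the following sketch (plain Rust over little-endian radix blocks; the helper and its names are illustrative, not part of the backend API):

```rust
/// Clear-value model of the `host_cast_to_unsigned` dispatch, operating on
/// little-endian radix blocks (least-significant block first). Hypothetical
/// helper for illustration only.
fn cast_blocks(mut blocks: Vec<u8>, target: usize, input_is_signed: bool, bits_per_block: u32) -> Vec<u8> {
    if target > blocks.len() {
        // Extend: negative signed inputs get all-ones padding blocks (radix
        // sign extension); unsigned or non-negative inputs get zero blocks.
        let negative = input_is_signed
            && (blocks.last().copied().unwrap_or(0) >> (bits_per_block - 1)) & 1 == 1;
        let pad = if negative { (1u8 << bits_per_block) - 1 } else { 0 };
        blocks.resize(target, pad);
    } else {
        // Trim (a no-op when target == blocks.len()): keep the `target`
        // least-significant blocks.
        blocks.truncate(target);
    }
    blocks
}
```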

View File

@@ -173,3 +173,51 @@ void cleanup_cuda_unchecked_match_value_64(CudaStreamsFFI streams,
delete mem_ptr;
*mem_ptr_void = nullptr;
}
uint64_t scratch_cuda_unchecked_match_value_or_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_matches, uint32_t num_input_blocks,
uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
uint32_t max_output_is_zero, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
big_lwe_dimension, small_lwe_dimension, ks_level,
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
message_modulus, carry_modulus, noise_reduction_type);
return scratch_cuda_unchecked_match_value_or<uint64_t>(
CudaStreams(streams),
(int_unchecked_match_value_or_buffer<uint64_t> **)mem_ptr, params,
num_matches, num_input_blocks, num_match_packed_blocks, num_final_blocks,
max_output_is_zero, allocate_gpu_memory);
}
void cuda_unchecked_match_value_or_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in_ct,
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
const uint64_t *h_or_value, int8_t *mem, void *const *bsks,
void *const *ksks) {
host_unchecked_match_value_or<uint64_t>(
CudaStreams(streams), lwe_array_out, lwe_array_in_ct, h_match_inputs,
h_match_outputs, h_or_value,
(int_unchecked_match_value_or_buffer<uint64_t> *)mem, bsks,
(uint64_t *const *)ksks);
}
void cleanup_cuda_unchecked_match_value_or_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
int_unchecked_match_value_or_buffer<uint64_t> *mem_ptr =
(int_unchecked_match_value_or_buffer<uint64_t> *)(*mem_ptr_void);
mem_ptr->release(CudaStreams(streams));
delete mem_ptr;
*mem_ptr_void = nullptr;
}

View File

@@ -1,5 +1,7 @@
#pragma once
#include "integer/cast.cuh"
#include "integer/cmux.cuh"
#include "integer/comparison.cuh"
#include "integer/integer.cuh"
#include "integer/radix_ciphertext.cuh"
@@ -453,3 +455,46 @@ uint64_t scratch_cuda_unchecked_match_value(
return size_tracker;
}
template <typename Torus>
uint64_t scratch_cuda_unchecked_match_value_or(
CudaStreams streams, int_unchecked_match_value_or_buffer<Torus> **mem_ptr,
int_radix_params params, uint32_t num_matches, uint32_t num_input_blocks,
uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
bool max_output_is_zero, bool allocate_gpu_memory) {
uint64_t size_tracker = 0;
*mem_ptr = new int_unchecked_match_value_or_buffer<Torus>(
streams, params, num_matches, num_input_blocks, num_match_packed_blocks,
num_final_blocks, max_output_is_zero, allocate_gpu_memory, size_tracker);
return size_tracker;
}
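// Runs the unchecked match (producing a match result and a boolean selector),
// uploads or_value into a trivial radix ciphertext, then cmux-selects between
// the match result and the or_value.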
template <typename Torus>
__host__ void host_unchecked_match_value_or(
CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in_ct,
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
const uint64_t *h_or_value,
int_unchecked_match_value_or_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
host_unchecked_match_value<Torus>(streams, mem_ptr->tmp_match_result,
mem_ptr->tmp_match_bool, lwe_array_in_ct,
h_match_inputs, h_match_outputs,
mem_ptr->match_buffer, bsks, ksks);
cuda_memcpy_async_to_gpu(mem_ptr->d_or_value, h_or_value,
mem_ptr->num_final_blocks * sizeof(Torus),
streams.stream(0), streams.gpu_index(0));
set_trivial_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), mem_ptr->tmp_or_value,
mem_ptr->d_or_value, (Torus *)h_or_value, mem_ptr->num_final_blocks,
mem_ptr->params.message_modulus, mem_ptr->params.carry_modulus);
host_cmux<Torus>(streams, lwe_array_out, mem_ptr->tmp_match_bool,
mem_ptr->tmp_match_result, mem_ptr->tmp_or_value,
mem_ptr->cmux_buffer, bsks, (Torus **)ksks);
}
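On clear values, the pipeline above reduces to a lookup with a fallback, as in this minimal Rust sketch (the helper name is hypothetical):

```rust
use std::collections::HashMap;

/// Clear-value semantics of `match_value_or`: return the mapped output when
/// the input appears among the match inputs, otherwise fall back to `or_value`.
fn match_value_or_clear(input: u64, matches: &HashMap<u64, u64>, or_value: u64) -> u64 {
    matches.get(&input).copied().unwrap_or(or_value)
}

// E.g., with the pair (17, 25) and or_value 55: input 17 maps to 25, while
// input 18 falls back to 55.
```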

View File

@@ -1270,6 +1270,13 @@ unsafe extern "C" {
streams: CudaStreamsFFI,
);
}
unsafe extern "C" {
pub fn trim_radix_blocks_msb_64(
output: *mut CudaRadixCiphertextFFI,
input: *const CudaRadixCiphertextFFI,
streams: CudaStreamsFFI,
);
}
unsafe extern "C" {
pub fn scratch_cuda_apply_noise_squashing(
streams: CudaStreamsFFI,
@@ -1872,6 +1879,89 @@ unsafe extern "C" {
mem_ptr_void: *mut *mut i8,
);
}
unsafe extern "C" {
pub fn scratch_cuda_cast_to_unsigned_64(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_input_blocks: u32,
target_num_blocks: u32,
input_is_signed: bool,
requires_full_propagate: bool,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_cast_to_unsigned_64(
streams: CudaStreamsFFI,
output: *mut CudaRadixCiphertextFFI,
input: *mut CudaRadixCiphertextFFI,
mem_ptr: *mut i8,
target_num_blocks: u32,
input_is_signed: bool,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_cast_to_unsigned_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
unsafe extern "C" {
pub fn scratch_cuda_unchecked_match_value_or_64(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_matches: u32,
num_input_blocks: u32,
num_match_packed_blocks: u32,
num_final_blocks: u32,
max_output_is_zero: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_unchecked_match_value_or_64(
streams: CudaStreamsFFI,
lwe_array_out: *mut CudaRadixCiphertextFFI,
lwe_array_in_ct: *const CudaRadixCiphertextFFI,
h_match_inputs: *const u64,
h_match_outputs: *const u64,
h_or_value: *const u64,
mem: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_unchecked_match_value_or_64(
streams: CudaStreamsFFI,
mem_ptr_void: *mut *mut i8,
);
}
unsafe extern "C" {
pub fn scratch_cuda_integer_compress_radix_ciphertext_64(
streams: CudaStreamsFFI,

View File

@@ -1363,6 +1363,63 @@ where
})
}
/// Returns the estimated memory usage (in bytes) required on the GPU to perform
/// the `match_value_or` operation.
///
/// This is useful for checking whether the operation fits in GPU memory before attempting
/// execution.
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16, MatchValues};
///
/// let config = ConfigBuilder::default().build();
/// let (client_key, server_key) = generate_keys(config);
/// set_server_key(server_key);
///
/// let a = FheUint16::encrypt(17u16, &client_key);
///
/// let match_values = MatchValues::new(vec![(0u16, 3u16), (17u16, 25u16)]).unwrap();
///
/// #[cfg(feature = "gpu")]
/// {
/// let size_bytes = a
/// .get_match_value_or_size_on_gpu(&match_values, 55u16)
/// .unwrap();
/// println!("Memory required on GPU: {} bytes", size_bytes);
/// assert!(size_bytes > 0);
/// }
/// ```
#[cfg(feature = "gpu")]
pub fn get_match_value_or_size_on_gpu<Clear>(
&self,
matches: &MatchValues<Clear>,
or_value: Clear,
) -> crate::Result<u64>
where
Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<usize>,
{
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(_) => Err(crate::Error::new(
"This function is only available when using the CUDA backend".to_string(),
)),
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let ct_on_gpu = self.ciphertext.on_gpu(streams);
let size = cuda_key.key.key.get_unchecked_match_value_or_size_on_gpu(
&ct_on_gpu, matches, or_value, streams,
);
Ok(size)
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support this operation.")
}
})
}
/// Reverse the bits of the unsigned integer
///
/// # Example

View File

@@ -653,3 +653,9 @@ fn test_compressed_cpk_encrypt_cast_compute_hl() {
let clear: u64 = mul.decrypt(&client_key);
assert_eq!(clear, (input_msg * multiplier) % modulus);
}
#[test]
fn test_match_value_or() {
let client_key = setup_default_cpu();
super::test_case_match_value_or(&client_key);
}

View File

@@ -960,3 +960,42 @@ fn test_gpu_get_match_value_size_on_gpu() {
assert!(memory_size > 0);
}
}
#[test]
fn test_match_value_or_gpu() {
let client_key = setup_classical_gpu();
super::test_case_match_value_or(&client_key);
}
#[test]
fn test_match_value_or_gpu_multibit() {
let client_key = setup_multibit_gpu();
super::test_case_match_value_or(&client_key);
}
#[test]
fn test_gpu_get_match_value_or_size_on_gpu() {
for setup_fn in GPU_SETUP_FN {
let cks = setup_fn();
let mut rng = rand::thread_rng();
let clear_a = rng.gen::<u32>();
let or_value = rng.gen::<u32>();
let mut a = FheUint32::try_encrypt(clear_a, &cks).unwrap();
a.move_to_current_device();
let match_values = MatchValues::new(vec![
(0u32, 10u32),
(1u32, 20u32),
(clear_a, 30u32),
(u32::MAX, 40u32),
])
.unwrap();
let memory_size = a
.get_match_value_or_size_on_gpu(&match_values, or_value)
.unwrap();
check_valid_cuda_malloc_assert_oom(memory_size, GpuIndex::new(0));
assert!(memory_size > 0);
}
}

View File

@@ -839,3 +839,52 @@ fn test_case_match_value(cks: &ClientKey) {
}
}
}
fn test_case_match_value_or(cks: &ClientKey) {
let mut rng = thread_rng();
for _ in 0..5 {
let clear_in = rng.gen::<u8>();
let ct = FheUint8::encrypt(clear_in, cks);
let clear_or_value = rng.gen::<u8>();
let should_match = rng.gen_bool(0.5);
let mut map: HashMap<u8, u8> = HashMap::new();
let mut pairs = Vec::new();
let expected_value = if should_match {
let val = rng.gen::<u8>();
map.insert(clear_in, val);
pairs.push((clear_in, val));
val
} else {
clear_or_value
};
let num_entries = rng.gen_range(1..10);
for _ in 0..num_entries {
let mut k = rng.gen::<u8>();
while !should_match && k == clear_in {
k = rng.gen::<u8>();
}
if let std::collections::hash_map::Entry::Vacant(e) = map.entry(k) {
let v = rng.gen::<u8>();
e.insert(v);
pairs.push((k, v));
}
}
let matches = MatchValues::new(pairs).unwrap();
let result: FheUint8 = ct.match_value_or(&matches, clear_or_value).unwrap();
let dec_result: u8 = result.decrypt(cks);
assert_eq!(
dec_result, expected_value,
"Mismatch on result value for input {clear_in}. Should match: {should_match}"
);
}
}

View File

@@ -71,18 +71,6 @@ impl CudaRadixCiphertextInfo {
new_block_info
}
pub(crate) fn after_trim_radix_blocks_msb(&self, num_blocks: usize) -> Self {
assert!(num_blocks > 0);
let mut new_block_info = Self {
blocks: Vec::with_capacity(self.blocks.len().saturating_sub(num_blocks)),
};
new_block_info
.blocks
.extend(self.blocks[..num_blocks].iter().copied());
new_block_info
}
pub fn duplicate(&self) -> Self {
Self {
blocks: self

View File

@@ -7504,6 +7504,34 @@ pub(crate) unsafe fn cuda_backend_trim_radix_blocks_lsb(
update_noise_degree(output, &cuda_ffi_output);
}
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_trim_radix_blocks_msb(
output: &mut CudaRadixCiphertext,
input: &CudaRadixCiphertext,
streams: &CudaStreams,
) {
let mut input_degrees = input.info.blocks.iter().map(|b| b.degree.0).collect();
let mut input_noise_levels = input.info.blocks.iter().map(|b| b.noise_level.0).collect();
let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect();
let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect();
let mut cuda_ffi_output =
prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
let cuda_ffi_input = prepare_cuda_radix_ffi(input, &mut input_degrees, &mut input_noise_levels);
trim_radix_blocks_msb_64(
&raw mut cuda_ffi_output,
&raw const cuda_ffi_input,
streams.ffi(),
);
update_noise_degree(output, &cuda_ffi_output);
}
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
@@ -8518,6 +8546,11 @@ pub(crate) unsafe fn cuda_backend_compute_equality_selectors<T: UnsignedInteger,
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_create_possible_results<
T: UnsignedInteger,
B: Numeric,
@@ -8654,6 +8687,11 @@ pub(crate) unsafe fn cuda_backend_create_possible_results<
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_aggregate_one_hot_vector<
T: UnsignedInteger,
B: Numeric,
@@ -8790,6 +8828,11 @@ pub(crate) unsafe fn cuda_backend_aggregate_one_hot_vector<
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_unchecked_match_value<
T: UnsignedInteger,
B: Numeric,
@@ -8945,6 +8988,11 @@ pub(crate) unsafe fn cuda_backend_unchecked_match_value<
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) fn cuda_backend_get_unchecked_match_value_size_on_gpu(
streams: &CudaStreams,
glwe_dimension: GlweDimension,
@@ -8999,3 +9047,276 @@ pub(crate) fn cuda_backend_get_unchecked_match_value_size_on_gpu(
size_tracker
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) fn cuda_backend_get_unchecked_match_value_or_size_on_gpu(
streams: &CudaStreams,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
grouping_factor: LweBskGroupingFactor,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
pbs_type: PBSType,
num_matches: u32,
num_input_blocks: u32,
num_match_packed_blocks: u32,
num_output_blocks: u32,
max_output_is_zero: bool,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) -> u64 {
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size_tracker = unsafe {
scratch_cuda_unchecked_match_value_or_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_matches,
num_input_blocks,
num_match_packed_blocks,
num_output_blocks,
max_output_is_zero as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
false,
noise_reduction_type as u32,
)
};
unsafe {
cleanup_cuda_unchecked_match_value_or_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr))
};
size_tracker
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_cast_to_unsigned<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
output: &mut CudaRadixCiphertext,
input: &mut CudaRadixCiphertext,
input_is_signed: bool,
requires_full_propagate: bool,
target_num_blocks: u32,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) {
let message_modulus = input.info.blocks.first().unwrap().message_modulus;
let carry_modulus = input.info.blocks.first().unwrap().carry_modulus;
let num_input_blocks = input.d_blocks.lwe_ciphertext_count().0 as u32;
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
let mut input_degrees: Vec<u64> = input.info.blocks.iter().map(|b| b.degree.0).collect();
let mut input_noise_levels: Vec<u64> =
input.info.blocks.iter().map(|b| b.noise_level.0).collect();
let mut cuda_ffi_input =
prepare_cuda_radix_ffi(input, &mut input_degrees, &mut input_noise_levels);
let mut output_degrees: Vec<u64> = output.info.blocks.iter().map(|b| b.degree.0).collect();
let mut output_noise_levels: Vec<u64> =
output.info.blocks.iter().map(|b| b.noise_level.0).collect();
let mut cuda_ffi_output =
prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_cast_to_unsigned_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_input_blocks,
target_num_blocks,
input_is_signed,
requires_full_propagate,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_cast_to_unsigned_64(
streams.ffi(),
&raw mut cuda_ffi_output,
&raw mut cuda_ffi_input,
mem_ptr,
target_num_blocks,
input_is_signed,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cleanup_cuda_cast_to_unsigned_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(output, &cuda_ffi_output);
if requires_full_propagate {
update_noise_degree(input, &cuda_ffi_input);
}
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_unchecked_match_value_or<
T: UnsignedInteger,
B: Numeric,
R: CudaIntegerRadixCiphertext,
>(
streams: &CudaStreams,
lwe_array_out: &mut R,
lwe_array_in_ct: &CudaRadixCiphertext,
h_match_inputs: &[u64],
h_match_outputs: &[u64],
h_or_value: &[u64],
num_matches: u32,
num_input_blocks: u32,
num_match_packed_blocks: u32,
num_final_blocks: u32,
max_output_is_zero: bool,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) {
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
let mut ffi_out_degrees: Vec<u64> = lwe_array_out
.as_ref()
.info
.blocks
.iter()
.map(|b| b.degree.get())
.collect();
let mut ffi_out_noise_levels: Vec<u64> = lwe_array_out
.as_ref()
.info
.blocks
.iter()
.map(|b| b.noise_level.0)
.collect();
let mut ffi_out_struct = prepare_cuda_radix_ffi(
lwe_array_out.as_ref(),
&mut ffi_out_degrees,
&mut ffi_out_noise_levels,
);
let mut ffi_in_ct_degrees: Vec<u64> = lwe_array_in_ct
.info
.blocks
.iter()
.map(|b| b.degree.get())
.collect();
let mut ffi_in_ct_noise_levels: Vec<u64> = lwe_array_in_ct
.info
.blocks
.iter()
.map(|b| b.noise_level.0)
.collect();
let ffi_in_ct_struct = prepare_cuda_radix_ffi(
lwe_array_in_ct,
&mut ffi_in_ct_degrees,
&mut ffi_in_ct_noise_levels,
);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_unchecked_match_value_or_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_matches,
num_input_blocks,
num_match_packed_blocks,
num_final_blocks,
max_output_is_zero as u32,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_unchecked_match_value_or_64(
streams.ffi(),
&raw mut ffi_out_struct,
&raw const ffi_in_ct_struct,
h_match_inputs.as_ptr(),
h_match_outputs.as_ptr(),
h_or_value.as_ptr(),
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cleanup_cuda_unchecked_match_value_or_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(lwe_array_out.as_mut(), &ffi_out_struct);
}

View File

@@ -17,11 +17,11 @@ use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
cuda_backend_apply_bivariate_lut, cuda_backend_apply_many_univariate_lut,
cuda_backend_apply_univariate_lut, cuda_backend_compute_prefix_sum_hillis_steele,
cuda_backend_extend_radix_with_sign_msb,
cuda_backend_apply_univariate_lut, cuda_backend_cast_to_unsigned,
cuda_backend_compute_prefix_sum_hillis_steele, cuda_backend_extend_radix_with_sign_msb,
cuda_backend_extend_radix_with_trivial_zero_blocks_msb, cuda_backend_full_propagate_assign,
cuda_backend_noise_squashing, cuda_backend_propagate_single_carry_assign,
cuda_backend_trim_radix_blocks_lsb, CudaServerKey, PBSType,
cuda_backend_trim_radix_blocks_lsb, cuda_backend_trim_radix_blocks_msb, CudaServerKey, PBSType,
};
use crate::integer::server_key::radix_parallel::OutputFlag;
use crate::shortint::ciphertext::{Degree, NoiseLevel};
@@ -584,31 +584,19 @@ impl CudaServerKey {
num_blocks: usize,
streams: &CudaStreams,
) -> T {
let total_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0;
if num_blocks == 0 {
return ct.duplicate(streams);
}
let new_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 - num_blocks;
let ciphertext_modulus = ct.as_ref().d_blocks.ciphertext_modulus();
let lwe_size = ct.as_ref().d_blocks.lwe_dimension().to_lwe_size();
let shift = new_num_blocks * lwe_size.0;
let new_num_blocks = total_blocks - num_blocks;
let mut trimmed_ct: T = self.create_trivial_zero_radix(new_num_blocks, streams);
let mut trimmed_ct_vec = CudaVec::new(new_num_blocks * lwe_size.0, streams, 0);
unsafe {
trimmed_ct_vec.copy_src_range_gpu_to_gpu_async(
0..shift,
&ct.as_ref().d_blocks.0.d_vec,
streams,
0,
);
streams.synchronize();
cuda_backend_trim_radix_blocks_msb(trimmed_ct.as_mut(), ct.as_ref(), streams);
}
let trimmed_ct_list = CudaLweCiphertextList::from_cuda_vec(
trimmed_ct_vec,
LweCiphertextCount(new_num_blocks),
ciphertext_modulus,
);
let trimmed_ct_info = ct.as_ref().info.after_trim_radix_blocks_msb(new_num_blocks);
T::from(CudaRadixCiphertext::new(trimmed_ct_list, trimmed_ct_info))
trimmed_ct
}
pub(crate) fn generate_lookup_table<F>(&self, f: F) -> LookupTableOwned
@@ -1339,48 +1327,71 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
if !source.block_carries_are_empty() {
self.full_propagate_assign(&mut source, streams);
}
let current_num_blocks = source.as_ref().info.blocks.len();
if T::IS_SIGNED {
// Casting from signed to unsigned
// We have to trim or sign extend first
if target_num_blocks > current_num_blocks {
let num_blocks_to_add = target_num_blocks - current_num_blocks;
let signed_res: T =
self.extend_radix_with_sign_msb(&source, num_blocks_to_add, streams);
<CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
signed_res.into_inner(),
)
} else {
let num_blocks_to_remove = current_num_blocks - target_num_blocks;
let signed_res = self.trim_radix_blocks_msb(&source, num_blocks_to_remove, streams);
<CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
signed_res.into_inner(),
)
}
} else {
// Casting from unsigned to unsigned, this is just about trimming/extending with zeros
if target_num_blocks > current_num_blocks {
let num_blocks_to_add = target_num_blocks - current_num_blocks;
let unsigned_res = self.extend_radix_with_trivial_zero_blocks_msb(
&source,
num_blocks_to_add,
streams,
);
<CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
unsigned_res.into_inner(),
)
} else {
let num_blocks_to_remove = current_num_blocks - target_num_blocks;
let unsigned_res =
self.trim_radix_blocks_msb(&source, num_blocks_to_remove, streams);
<CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
unsigned_res.into_inner(),
)
let mut result: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(target_num_blocks, streams);
let requires_full_propagate = !source.block_carries_are_empty();
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_cast_to_unsigned(
streams,
result.as_mut(),
source.as_mut(),
T::IS_SIGNED,
requires_full_propagate,
target_num_blocks as u32,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_cast_to_unsigned(
streams,
result.as_mut(),
source.as_mut(),
T::IS_SIGNED,
requires_full_propagate,
target_num_blocks as u32,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
result
}
/// Cast a `CudaUnsignedRadixCiphertext` or `CudaSignedRadixCiphertext` to a

View File

@@ -494,6 +494,46 @@ where
}
}
/// For match_value_or operation
impl<'a, F>
OpSequenceFunctionExecutor<(&'a RadixCiphertext, &'a MatchValues<u64>, u64), RadixCiphertext>
for OpSequenceGpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&MatchValues<u64>,
u64,
&CudaStreams,
) -> CudaUnsignedRadixCiphertext,
{
fn setup(
&mut self,
cks: &RadixClientKey,
sks: &CompressedServerKey,
seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
self.setup_from_gpu_keys(cks, sks, seeder);
}
fn execute(
&mut self,
input: (&'a RadixCiphertext, &'a MatchValues<u64>, u64),
) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, input.2, &context.streams);
d_res.to_radix_ciphertext(&context.streams)
}
}
impl<'a, F>
OpSequenceFunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext),

View File

@@ -5,9 +5,10 @@ use crate::integer::gpu::CudaServerKey;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_long_run::test_random_op_sequence::{
random_op_sequence_test, BinaryOpExecutor, ComparisonOpExecutor, DivRemOpExecutor,
Log2OpExecutor, MatchValueExecutor, OprfBoundedExecutor, OprfCustomRangeExecutor, OprfExecutor,
OverflowingOpExecutor, ScalarBinaryOpExecutor, ScalarComparisonOpExecutor,
ScalarDivRemOpExecutor, ScalarOverflowingOpExecutor, SelectOpExecutor, UnaryOpExecutor,
Log2OpExecutor, MatchValueExecutor, MatchValueOrExecutor, OprfBoundedExecutor,
OprfCustomRangeExecutor, OprfExecutor, OverflowingOpExecutor, ScalarBinaryOpExecutor,
ScalarComparisonOpExecutor, ScalarDivRemOpExecutor, ScalarOverflowingOpExecutor,
SelectOpExecutor, UnaryOpExecutor,
};
use crate::integer::server_key::radix_parallel::tests_long_run::{
get_user_defined_seed, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
@@ -54,6 +55,7 @@ pub(crate) fn random_op_sequence_test_init_gpu<P>(
)],
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
match_value_ops: &mut [(MatchValueExecutor, String)],
match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
oprf_ops: &mut [(OprfExecutor, String)],
oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -88,6 +90,7 @@ where
+ scalar_div_rem_op.len()
+ log2_ops.len()
+ match_value_ops.len()
+ match_value_or_ops.len()
+ oprf_ops.len()
+ oprf_bounded_ops.len()
+ oprf_custom_range_ops.len();
@@ -142,6 +145,9 @@ where
for x in match_value_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in match_value_or_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in oprf_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
@@ -317,6 +323,7 @@ where
)];
let mut match_value_ops: Vec<(MatchValueExecutor, String)> = vec![];
let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![];
let mut oprf_ops: Vec<(OprfExecutor, String)> = vec![];
let mut oprf_bounded_ops: Vec<(OprfBoundedExecutor, String)> = vec![];
@@ -336,6 +343,7 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
&mut match_value_ops,
&mut match_value_or_ops,
&mut oprf_ops,
&mut oprf_bounded_ops,
&mut oprf_custom_range_ops,
@@ -357,6 +365,7 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
&mut match_value_ops,
&mut match_value_or_ops,
&mut oprf_ops,
&mut oprf_bounded_ops,
&mut oprf_custom_range_ops,
@@ -781,6 +790,14 @@ where
let mut match_value_ops: Vec<(MatchValueExecutor, String)> =
vec![(Box::new(match_value_executor), "match_value".to_string())];
// Match Value Or Executor
let match_value_or_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::match_value_or);
let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![(
Box::new(match_value_or_executor),
"match_value_or".to_string(),
)];
// OPRF Executors
let oprf_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(
&CudaServerKey::par_generate_oblivious_pseudo_random_unsigned_integer,
@@ -816,6 +833,7 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
&mut match_value_ops,
&mut match_value_or_ops,
&mut oprf_ops,
&mut oprf_bounded_ops,
&mut oprf_custom_range_ops,
@@ -837,6 +855,7 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
&mut match_value_ops,
&mut match_value_or_ops,
&mut oprf_ops,
&mut oprf_bounded_ops,
&mut oprf_custom_range_ops,

View File

@@ -9,8 +9,9 @@ use crate::integer::gpu::server_key::radix::CudaRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{
cuda_backend_aggregate_one_hot_vector, cuda_backend_compute_equality_selectors,
cuda_backend_create_possible_results, cuda_backend_get_unchecked_match_value_size_on_gpu,
cuda_backend_unchecked_match_value, PBSType,
cuda_backend_create_possible_results, cuda_backend_get_unchecked_match_value_or_size_on_gpu,
cuda_backend_get_unchecked_match_value_size_on_gpu, cuda_backend_unchecked_match_value,
cuda_backend_unchecked_match_value_or, PBSType,
};
pub use crate::integer::server_key::radix_parallel::MatchValues;
use crate::prelude::CastInto;
@@ -410,26 +411,229 @@ impl CudaServerKey {
Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<usize>,
{
if matches.get_values().is_empty() {
let ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(
or_value,
self.num_blocks_to_represent_unsigned_value(or_value),
streams,
);
let num_blocks = self.num_blocks_to_represent_unsigned_value(or_value);
let ct: CudaUnsignedRadixCiphertext =
self.create_trivial_radix(or_value, num_blocks, streams);
return ct;
}
let (result, selected) = self.unchecked_match_value(ct, matches, streams);
// The result must have as many blocks as needed to represent either the result of the match
// or the or_value
let num_blocks_to_represent_or_value =
self.num_blocks_to_represent_unsigned_value(or_value);
let num_blocks = (result.as_ref().d_blocks.lwe_ciphertext_count().0)
.max(num_blocks_to_represent_or_value);
let or_value: CudaUnsignedRadixCiphertext =
self.create_trivial_radix(or_value, num_blocks, streams);
let casted_result = self.cast_to_unsigned(result, num_blocks, streams);
// Note, this could be slightly faster when we have scalar if then_else
self.unchecked_if_then_else(&selected, &casted_result, &or_value, streams)
let num_input_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
let num_bits_in_message = self.message_modulus.0.ilog2();
let h_match_inputs: Vec<u64> = matches
.get_values()
.par_iter()
.map(|(input, _output)| *input)
.flat_map(|input_value| {
BlockDecomposer::new(input_value, num_bits_in_message)
.take(num_input_blocks as usize)
.map(|block_value| block_value.cast_into())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let max_output_value_match = matches
.get_values()
.iter()
.copied()
.max_by(|(_, outputl), (_, outputr)| outputl.cmp(outputr))
.expect("luts is not empty at this point")
.1;
let num_blocks_match = self.num_blocks_to_represent_unsigned_value(max_output_value_match);
let num_blocks_or = self.num_blocks_to_represent_unsigned_value(or_value);
let final_num_blocks = num_blocks_match.max(num_blocks_or);
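// Match outputs are packed two message blocks per ciphertext block
// (2 * num_bits_in_message bits each), hence the div_ceil(2) below.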
let num_match_packed_blocks = num_blocks_match.div_ceil(2) as u32;
let h_match_outputs: Vec<u64> = matches
.get_values()
.par_iter()
.map(|(_input, output)| *output)
.flat_map(|output_value| {
BlockDecomposer::new(output_value, 2 * num_bits_in_message)
.take(num_match_packed_blocks as usize)
.map(|block_value| block_value.cast_into())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let h_or_value: Vec<u64> = BlockDecomposer::new(or_value, num_bits_in_message)
.take(final_num_blocks)
.map(|block_value| block_value.cast_into())
.collect();
let mut result: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(final_num_blocks, streams);
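// When every match output is zero, the backend skips allocating its
// match-result temporary (see int_unchecked_match_value_or_buffer).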
let max_output_is_zero = max_output_value_match == Clear::ZERO;
let num_matches = matches.get_values().len() as u32;
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_match_value_or(
streams,
&mut result,
ct.as_ref(),
&h_match_inputs,
&h_match_outputs,
&h_or_value,
num_matches,
num_input_blocks,
num_match_packed_blocks,
final_num_blocks as u32,
max_output_is_zero,
self.message_modulus,
self.carry_modulus,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_match_value_or(
streams,
&mut result,
ct.as_ref(),
&h_match_inputs,
&h_match_outputs,
&h_or_value,
num_matches,
num_input_blocks,
num_match_packed_blocks,
final_num_blocks as u32,
max_output_is_zero,
self.message_modulus,
self.carry_modulus,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
result
}
pub fn get_unchecked_match_value_or_size_on_gpu<Clear>(
&self,
ct: &CudaUnsignedRadixCiphertext,
matches: &MatchValues<Clear>,
or_value: Clear,
streams: &CudaStreams,
) -> u64
where
Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<usize>,
{
if matches.get_values().is_empty() {
return 0;
}
let num_input_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
let max_output_value_match = matches
.get_values()
.iter()
.copied()
.max_by(|(_, outputl), (_, outputr)| outputl.cmp(outputr))
.expect("luts is not empty at this point")
.1;
let num_blocks_match = self.num_blocks_to_represent_unsigned_value(max_output_value_match);
let num_blocks_or = self.num_blocks_to_represent_unsigned_value(or_value);
let final_num_blocks = num_blocks_match.max(num_blocks_or);
let num_match_packed_blocks = num_blocks_match.div_ceil(2) as u32;
let max_output_is_zero = max_output_value_match == Clear::ZERO;
let num_matches = matches.get_values().len() as u32;
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_unchecked_match_value_or_size_on_gpu(
streams,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
num_matches,
num_input_blocks,
num_match_packed_blocks,
final_num_blocks as u32,
max_output_is_zero,
d_bsk.ms_noise_reduction_configuration.as_ref(),
)
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_get_unchecked_match_value_or_size_on_gpu(
streams,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
num_matches,
num_input_blocks,
num_match_packed_blocks,
final_num_blocks as u32,
max_output_is_zero,
None,
)
}
}
}
/// `match` an input value to an output value

View File

@@ -267,9 +267,7 @@ impl<
}
}
}
}
impl RandomOpSequenceDataGenerator<u64, RadixCiphertext> {
#[allow(clippy::manual_is_multiple_of)]
pub(crate) fn gen_match_values(&mut self, key_to_match: u64) -> (MatchValues<u64>, u64, bool) {
let mut pairings = Vec::new();
@@ -297,6 +295,19 @@ impl RandomOpSequenceDataGenerator<u64, RadixCiphertext> {
does_match,
)
}
pub(crate) fn gen_match_values_or(
&mut self,
key_to_match: u64,
) -> (MatchValues<u64>, u64, u64) {
let (mv, match_val, does_match) = self.gen_match_values(key_to_match);
let or_value = self.deterministic_seeder.seed().0 as u64;
let expected = if does_match { match_val } else { or_value };
(mv, or_value, expected)
}
}
#[allow(clippy::too_many_arguments)]

View File

@@ -79,6 +79,13 @@ pub(crate) type MatchValueExecutor = Box<
>,
>;
pub(crate) type MatchValueOrExecutor = Box<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a RadixCiphertext, &'a MatchValues<u64>, u64),
RadixCiphertext,
>,
>;
pub(crate) type OprfExecutor =
Box<dyn for<'a> OpSequenceFunctionExecutor<(Seed, u64), RadixCiphertext>>;
@@ -474,6 +481,15 @@ where
let mut match_value_ops: Vec<(MatchValueExecutor, String)> =
vec![(Box::new(match_value_executor), "match_value".to_string())];
// Match Value Or Executor
let match_value_or_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::match_value_or_parallelized);
let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![(
Box::new(match_value_or_executor),
"match_value_or".to_string(),
)];
// OPRF Executors
let oprf_executor = OpSequenceCpuFunctionExecutor::new(
&ServerKey::par_generate_oblivious_pseudo_random_unsigned_integer,
@@ -514,6 +530,7 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
&mut match_value_ops,
&mut match_value_or_ops,
&mut oprf_ops,
&mut oprf_bounded_ops,
&mut oprf_custom_range_ops,
@@ -535,6 +552,7 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
&mut match_value_ops,
&mut match_value_or_ops,
&mut oprf_ops,
&mut oprf_bounded_ops,
&mut oprf_custom_range_ops,
@@ -572,6 +590,7 @@ pub(crate) fn random_op_sequence_test_init_cpu<P>(
)],
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
match_value_ops: &mut [(MatchValueExecutor, String)],
match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
oprf_ops: &mut [(OprfExecutor, String)],
oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -602,6 +621,7 @@ where
+ scalar_div_rem_op.len()
+ log2_ops.len()
+ match_value_ops.len()
+ match_value_or_ops.len()
+ oprf_ops.len()
+ oprf_bounded_ops.len()
+ oprf_custom_range_ops.len();
@@ -661,6 +681,9 @@ where
for x in match_value_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in match_value_or_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in oprf_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
@@ -706,6 +729,7 @@ pub(crate) fn random_op_sequence_test(
)],
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
match_value_ops: &mut [(MatchValueExecutor, String)],
match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
oprf_ops: &mut [(OprfExecutor, String)],
oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -729,7 +753,10 @@ pub(crate) fn random_op_sequence_test(
div_rem_op_range.end..div_rem_op_range.end + scalar_div_rem_op.len();
let log2_ops_range = scalar_div_rem_op_range.end..scalar_div_rem_op_range.end + log2_ops.len();
let match_value_ops_range = log2_ops_range.end..log2_ops_range.end + match_value_ops.len();
let oprf_ops_range = match_value_ops_range.end..match_value_ops_range.end + oprf_ops.len();
let match_value_or_ops_range =
match_value_ops_range.end..match_value_ops_range.end + match_value_or_ops.len();
let oprf_ops_range =
match_value_or_ops_range.end..match_value_or_ops_range.end + oprf_ops.len();
let oprf_bounded_ops_range = oprf_ops_range.end..oprf_ops_range.end + oprf_bounded_ops.len();
let oprf_custom_range_ops_range =
oprf_bounded_ops_range.end..oprf_bounded_ops_range.end + oprf_custom_range_ops.len();
@@ -1149,6 +1176,35 @@ pub(crate) fn random_op_sequence_test(
operand.p,
operand.p,
);
} else if match_value_or_ops_range.contains(&i) {
let index = i - match_value_or_ops_range.start;
let (match_value_or_executor, fn_name) = &mut match_value_or_ops[index];
let operand = datagen.gen_op_single_operand(idx, fn_name);
let (match_values, or_value, expected_value) = datagen.gen_match_values_or(operand.p);
println!("{idx}: MatchValuesOr generated. Expected: {expected_value}");
let res = match_value_or_executor.execute((&operand.c, &match_values, or_value));
// Determinism check
let res_1 = match_value_or_executor.execute((&operand.c, &match_values, or_value));
let decrypted_res: u64 = cks.decrypt(&res);
let res_casted = sks.cast_to_unsigned(res.clone(), NB_CTXT_LONG_RUN);
datagen.put_op_result_random_side(expected_value, &res_casted, fn_name, idx);
sanity_check_op_sequence_result_u64(
idx,
fn_name,
fn_index,
&res,
&res_1,
decrypted_res,
expected_value,
operand.p,
operand.p,
);
} else if oprf_ops_range.contains(&i) {
let index = i - oprf_ops_range.start;
let (op_executor, fn_name) = &mut oprf_ops[index];