Compare commits

...

1 Commits

Author SHA1 Message Date
Guillermo Oyarzun
00ecb8306a feat(gpu): implement sub array search on GPU 2024-08-13 18:36:07 +02:00
5 changed files with 1441 additions and 17 deletions

View File

@@ -18,8 +18,10 @@ use crate::integer::gpu::{
CudaServerKey, PBSType,
};
use crate::shortint::ciphertext::{Degree, NoiseLevel};
use crate::shortint::engine::fill_accumulator;
use crate::shortint::server_key::{BivariateLookupTableOwned, LookupTableOwned};
use crate::shortint::engine::{fill_accumulator, fill_many_lut_accumulator};
use crate::shortint::server_key::{
BivariateLookupTableOwned, LookupTableOwned, ManyLookupTableOwned,
};
use crate::shortint::PBSOrder;
mod add;
@@ -41,6 +43,7 @@ mod scalar_shift;
mod scalar_sub;
mod shift;
mod sub;
mod vector_find;
#[cfg(test)]
mod tests_signed;
@@ -647,6 +650,36 @@ impl CudaServerKey {
T::from(CudaRadixCiphertext::new(trimmed_ct_list, trimmed_ct_info))
}
/*
pub fn generate_lookup_table<F>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ciphertext_modulus: CiphertextModulus,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
f: F,
) -> LookupTableOwned
where
F: Fn(u64) -> u64,
{
let mut acc = GlweCiphertext::new(0, glwe_size, polynomial_size, ciphertext_modulus);
let max_value = fill_accumulator(
&mut acc,
polynomial_size,
glwe_size,
message_modulus,
carry_modulus,
f,
);
LookupTableOwned {
acc,
degree: Degree::new(max_value as usize),
}
}
*/
pub(crate) fn generate_lookup_table<F>(&self, f: F) -> LookupTableOwned
where
F: Fn(u64) -> u64,
@@ -699,6 +732,38 @@ impl CudaServerKey {
}
}
pub(crate) fn generate_many_lookup_table<F>(&self, f: &[F]) -> ManyLookupTableOwned
where
F: Fn(u64) -> u64,
{
let (glwe_size, polynomial_size) = match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
(d_bsk.glwe_dimension.to_glwe_size(), d_bsk.polynomial_size)
}
CudaBootstrappingKey::MultiBit(d_bsk) => {
(d_bsk.glwe_dimension.to_glwe_size(), d_bsk.polynomial_size)
}
};
let mut acc = GlweCiphertext::new(0, glwe_size, polynomial_size, self.ciphertext_modulus);
let (input_max_degree, sample_extraction_stride, per_function_output_degree) =
fill_many_lut_accumulator(
&mut acc,
polynomial_size,
glwe_size,
self.message_modulus,
self.carry_modulus,
&f,
);
ManyLookupTableOwned {
acc,
input_max_degree,
sample_extraction_stride,
per_function_output_degree,
}
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,7 @@ use std::ops::Range;
/// Input values are not required to span all possible values that
/// ` ct` could hold.
#[derive(Debug)]
pub struct MatchValues<Clear>(Vec<(Clear, Clear)>);
pub struct MatchValues<Clear>(pub Vec<(Clear, Clear)>);
impl<Clear> MatchValues<Clear> {
/// Builds a `MatchValues` from a Vec of tuple where in each tuple element,

View File

@@ -160,32 +160,30 @@ pub(crate) fn fill_accumulator_no_encoding<F, C>(
}
/// Fills a GlweCiphertext for use in a ManyLookupTable setting
pub(crate) fn fill_many_lut_accumulator<C>(
pub(crate) fn fill_many_lut_accumulator<F, C>(
accumulator: &mut GlweCiphertext<C>,
server_key: &ServerKey,
functions: &[&dyn Fn(u64) -> u64],
polynomial_size: PolynomialSize,
glwe_size: GlweSize,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
functions: &[F],
) -> (MaxDegree, usize, Vec<Degree>)
where
C: ContainerMut<Element = u64>,
F: Fn(u64) -> u64,
{
assert_eq!(
accumulator.polynomial_size(),
server_key.bootstrapping_key.polynomial_size()
);
assert_eq!(
accumulator.glwe_size(),
server_key.bootstrapping_key.glwe_size()
);
assert_eq!(accumulator.polynomial_size(), polynomial_size);
assert_eq!(accumulator.glwe_size(), glwe_size);
let mut accumulator_view = accumulator.as_mut_view();
accumulator_view.get_mut_mask().as_mut().fill(0);
// Modulus of the msg contained in the msg bits and operations buffer
let modulus_sup = server_key.message_modulus.0 * server_key.carry_modulus.0;
let modulus_sup = message_modulus.0 * carry_modulus.0;
// N/(p/2) = size of each block
let box_size = server_key.bootstrapping_key.polynomial_size().0 / modulus_sup;
let box_size = polynomial_size.0 / modulus_sup;
// Value of the delta we multiply our messages by
let delta = (1_u64 << 63) / (modulus_sup as u64);

View File

@@ -883,7 +883,14 @@ impl ServerKey {
self.ciphertext_modulus,
);
let (input_max_degree, sample_extraction_stride, per_function_output_degree) =
fill_many_lut_accumulator(&mut acc, self, functions);
fill_many_lut_accumulator(
&mut acc,
self.bootstrapping_key.polynomial_size(),
self.bootstrapping_key.glwe_size(),
self.message_modulus,
self.carry_modulus,
functions,
);
ManyLookupTableOwned {
acc,