Compare commits


1 Commit

Author: sarah el kazdadi
SHA1: 69b37f3b06
Message: feat(multibit): implement deterministic multibit pbs
Date: 2023-05-15 15:42:12 +02:00


@@ -524,6 +524,257 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
});
}
pub fn multi_bit_blind_rotate_assign_deterministic<Scalar, InputCont, OutputCont, KeyCont>(
input: &LweCiphertext<InputCont>,
lut: &mut GlweCiphertext<OutputCont>,
multi_bit_bsk: &FourierLweMultiBitBootstrapKey<KeyCont>,
thread_count: ThreadCount,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
KeyCont: Container<Element = c64> + Sync,
{
let (lwe_mask, lwe_body) = input.get_mask_and_body();
// No way to chunk the result of ggsw_iter at the moment
let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();
let mut work_queue = Vec::with_capacity(multi_bit_bsk.multi_bit_input_lwe_dimension().0);
let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
for (task_idx, (lwe_mask_elements, ggsw_bunch)) in lwe_mask
.as_ref()
.chunks_exact(grouping_factor.0)
.zip(ggsw_vec.chunks_exact(ggsw_per_multi_bit_element.0))
.enumerate()
{
work_queue.push((task_idx, lwe_mask_elements, ggsw_bunch));
}
assert!(work_queue.len() == lwe_mask.lwe_dimension().0 / grouping_factor.0);
let work_queue = &work_queue;
// The idea is that, at any point in time, half of the ring buffer holds data already processed
// and ready for the consumer while the other half is being filled by the producer threads.
let ring_buffer_size: usize = thread_count.0;
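// Note: ring_buffer_size == thread_count and producer thread `thread_idx` only handles tasks
// with task_idx % thread_count == thread_idx, so each producer always writes to the same slot
// (its own), while the consumer below walks the slots in increasing task_idx order. The order
// in which the external products are accumulated is therefore fixed across runs, independently
// of how the OS schedules the producer threads.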
let lut_poly_size = lut.polynomial_size();
let monomial_degree = pbs_modulus_switch(
*lwe_body.data,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
// Modulus switching
lut.as_mut_polynomial_list()
.iter_mut()
.for_each(|mut poly| {
polynomial_wrapping_monic_monomial_div_assign(
&mut poly,
MonomialDegree(monomial_degree),
)
});
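// At this point the lut has been rotated by X^{-b~} (the modulus switched body), exactly as in
// the regular blind rotation; the mask contributions are applied via the external products below.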
let fourier_multi_bit_ggsw_buffers = (0..ring_buffer_size)
.map(|_| {
(
Mutex::new(false),
Condvar::new(),
Mutex::new(FourierGgswCiphertext::new(
multi_bit_bsk.glwe_size(),
multi_bit_bsk.polynomial_size(),
multi_bit_bsk.decomposition_base_log(),
multi_bit_bsk.decomposition_level_count(),
)),
)
})
.collect::<Vec<_>>();
thread::scope(|s| {
let produce_multi_bit_fourier_ggsw = |thread_idx| {
let mut buffers = ComputationBuffers::new();
let fft = Fft::new(multi_bit_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
let mut unit_polynomial =
Polynomial::new(Scalar::ZERO, multi_bit_bsk.polynomial_size());
unit_polynomial.as_mut()[0] = Scalar::ONE;
let mut a_monomial = unit_polynomial.clone();
let mut fourier_a_monomial = FourierPolynomial::new(multi_bit_bsk.polynomial_size());
let fourier_multi_bit_ggsw_buffers = &fourier_multi_bit_ggsw_buffers;
for (task_idx, lwe_mask_elements, ggsw_bunch) in work_queue
.iter()
.skip(thread_idx)
.step_by(thread_count.0)
.copied()
{
let dest_idx = task_idx % ring_buffer_size;
let (ready_for_consumer_lock, condvar, fourier_ggsw_buffer) =
&fourier_multi_bit_ggsw_buffers[dest_idx];
let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();
// Wait while the slot still holds data that the consumer has not used yet; the condvar notifies
// us once the consumer is done and the buffer can be filled again
while *ready_for_consumer {
ready_for_consumer = condvar.wait(ready_for_consumer).unwrap();
}
let mut fourier_ggsw_buffer = fourier_ggsw_buffer.lock().unwrap();
let mut bunch_iter = ggsw_bunch.iter();
// Keygen guarantees that the first GGSW of the bunch corresponds to a constant (degree 0)
// monomial, so it can be copied directly, no polynomial multiplication required
let ggsw_a_none = bunch_iter.next().unwrap();
fourier_ggsw_buffer
.as_mut_view()
.data()
.copy_from_slice(ggsw_a_none.as_view().data());
let multi_bit_fourier_ggsw = fourier_ggsw_buffer.as_mut_view().data();
for (ggsw_idx, fourier_ggsw) in bunch_iter.enumerate() {
// We already processed the first ggsw, advance the index by 1
let ggsw_idx = ggsw_idx + 1;
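// The bits of ggsw_idx select which of this group's mask elements are summed into the degree
// of the monomial that multiplies this GGSW before it is accumulated into the buffer.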
let mut monomial_degree = Scalar::ZERO;
for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() {
let mask_position = lwe_mask_elements.len() - (mask_idx + 1);
let selection_bit: Scalar =
Scalar::cast_from((ggsw_idx >> mask_position) & 1);
monomial_degree =
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
}
let switched_degree = pbs_modulus_switch(
monomial_degree,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
a_monomial
.as_mut()
.copy_from_slice(unit_polynomial.as_ref());
polynomial_wrapping_monic_monomial_mul_assign(
&mut a_monomial,
MonomialDegree(switched_degree),
);
fft.forward_as_integer(
fourier_a_monomial.as_mut_view(),
a_monomial.as_view(),
buffers.stack(),
);
update_with_fmadd(
multi_bit_fourier_ggsw,
fourier_ggsw.as_view().data(),
fourier_a_monomial.as_view().data,
false,
lut_poly_size.to_fourier_polynomial_size().0,
);
}
// Drop the lock before we wake other threads
drop(fourier_ggsw_buffer);
*ready_for_consumer = true;
// Wake threads waiting on the condvar
condvar.notify_all();
}
};
let threads: Vec<_> = (0..thread_count.0)
.map(|idx| s.spawn(move || produce_multi_bit_fourier_ggsw(idx)))
.collect();
// We initialize ct0 for the successive external products
let ct0 = lut;
let mut ct1 = GlweCiphertext::new(
Scalar::ZERO,
ct0.glwe_size(),
ct0.polynomial_size(),
ct0.ciphertext_modulus(),
);
let ct1 = &mut ct1;
let mut buffers = ComputationBuffers::new();
let fft = Fft::new(multi_bit_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(
add_external_product_assign_scratch::<Scalar>(
multi_bit_bsk.glwe_size(),
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let mut src_idx = 1usize;
for (ready_lock, condvar, multi_bit_fourier_ggsw) in fourier_multi_bit_ggsw_buffers
.iter()
.cycle()
.take(multi_bit_bsk.multi_bit_input_lwe_dimension().0)
{
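// Slots are visited in round-robin order, i.e. in increasing task_idx, so the sequence of
// external products (and therefore the floating point rounding) is identical on every run.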
src_idx ^= 1;
let (src_ct, mut dst_ct) = if src_idx == 0 {
(ct0.as_view(), ct1.as_mut_view())
} else {
(ct1.as_view(), ct0.as_mut_view())
};
dst_ct.as_mut().fill(Scalar::ZERO);
let mut ready = ready_lock.lock().unwrap();
while !*ready {
ready = condvar.wait(ready).unwrap();
}
let multi_bit_fourier_ggsw = multi_bit_fourier_ggsw.lock().unwrap();
add_external_product_assign(
dst_ct,
multi_bit_fourier_ggsw.as_view(),
src_ct,
fft,
buffers.stack(),
);
*ready = false;
// Wake a single producer thread sleeping on the condvar (only one will get to work anyway)
condvar.notify_one();
}
if src_idx == 0 {
ct0.as_mut().copy_from_slice(ct1.as_ref());
}
threads.into_iter().for_each(|t| t.join().unwrap());
});
}
/// Perform a programmable bootstrap given an input [`LWE ciphertext`](`LweCiphertext`), a
/// look-up table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE multi-bit bootstrap
/// key`](`LweMultiBitBootstrapKey`) in the fourier domain. The result is written in the provided
@@ -743,7 +994,7 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
thread_count: ThreadCount,
) where
// CastInto required for PBS modulus switch which returns a usize
- Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
+ Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync + Send,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
AccCont: Container<Element = Scalar>,
@@ -811,7 +1062,12 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
.as_mut()
.copy_from_slice(accumulator.as_ref());
- multi_bit_blind_rotate_assign(input, &mut local_accumulator, multi_bit_bsk, thread_count);
+ multi_bit_blind_rotate_assign_deterministic(
+     input,
+     &mut local_accumulator,
+     multi_bit_bsk,
+     thread_count,
+ );
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
}
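
For context, a minimal, hypothetical caller sketch (not part of this commit): it assumes `lwe_in`, `lwe_out`, the `accumulator` GLWE and `fourier_multi_bit_bsk` have already been generated and converted elsewhere with the existing core_crypto helpers, and simply drives the new deterministic blind rotation followed by the usual sample extraction.

// Hypothetical usage sketch; the ciphertexts and the Fourier multi-bit bootstrap key are
// assumed to exist already (their construction is not shown in this diff).
let thread_count = ThreadCount(4);
multi_bit_blind_rotate_assign_deterministic(
    &lwe_in,
    &mut accumulator,
    &fourier_multi_bit_bsk,
    thread_count,
);
// Sample extract the constant coefficient, as in the non-deterministic path.
extract_lwe_sample_from_glwe_ciphertext(&accumulator, &mut lwe_out, MonomialDegree(0));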