mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00
Compare commits
1 Commits
tfhe-rs-1.
...
multibit-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
69b37f3b06 |
@@ -524,6 +524,257 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
});
|
||||
}
|
||||
|
||||
pub fn multi_bit_blind_rotate_assign_deterministic<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
input: &LweCiphertext<InputCont>,
|
||||
lut: &mut GlweCiphertext<OutputCont>,
|
||||
multi_bit_bsk: &FourierLweMultiBitBootstrapKey<KeyCont>,
|
||||
thread_count: ThreadCount,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
KeyCont: Container<Element = c64> + Sync,
|
||||
{
|
||||
let (lwe_mask, lwe_body) = input.get_mask_and_body();
|
||||
|
||||
// No way to chunk the result of ggsw_iter at the moment
|
||||
let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();
|
||||
let mut work_queue = Vec::with_capacity(multi_bit_bsk.multi_bit_input_lwe_dimension().0);
|
||||
|
||||
let grouping_factor = multi_bit_bsk.grouping_factor();
|
||||
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
|
||||
|
||||
for (task_idx, (lwe_mask_elements, ggsw_bunch)) in lwe_mask
|
||||
.as_ref()
|
||||
.chunks_exact(grouping_factor.0)
|
||||
.zip(ggsw_vec.chunks_exact(ggsw_per_multi_bit_element.0))
|
||||
.enumerate()
|
||||
{
|
||||
work_queue.push((task_idx, lwe_mask_elements, ggsw_bunch));
|
||||
}
|
||||
|
||||
assert!(work_queue.len() == lwe_mask.lwe_dimension().0 / grouping_factor.0);
|
||||
|
||||
let work_queue = &work_queue;
|
||||
|
||||
// At any point in time the idea is to have half the buffer already processed ready for the
|
||||
// consumer and the other half being filled by the producer threads.
|
||||
let ring_buffer_size: usize = thread_count.0;
|
||||
|
||||
let lut_poly_size = lut.polynomial_size();
|
||||
let monomial_degree = pbs_modulus_switch(
|
||||
*lwe_body.data,
|
||||
lut_poly_size,
|
||||
ModulusSwitchOffset(0),
|
||||
LutCountLog(0),
|
||||
);
|
||||
|
||||
// Modulus switching
|
||||
lut.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.for_each(|mut poly| {
|
||||
polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(monomial_degree),
|
||||
)
|
||||
});
|
||||
|
||||
let fourier_multi_bit_ggsw_buffers = (0..ring_buffer_size)
|
||||
.map(|_| {
|
||||
(
|
||||
Mutex::new(false),
|
||||
Condvar::new(),
|
||||
Mutex::new(FourierGgswCiphertext::new(
|
||||
multi_bit_bsk.glwe_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
multi_bit_bsk.decomposition_base_log(),
|
||||
multi_bit_bsk.decomposition_level_count(),
|
||||
)),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
thread::scope(|s| {
|
||||
let produce_multi_bit_fourier_ggsw = |thread_idx| {
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
let fft = Fft::new(multi_bit_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
|
||||
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
|
||||
|
||||
let mut unit_polynomial =
|
||||
Polynomial::new(Scalar::ZERO, multi_bit_bsk.polynomial_size());
|
||||
unit_polynomial.as_mut()[0] = Scalar::ONE;
|
||||
let mut a_monomial = unit_polynomial.clone();
|
||||
let mut fourier_a_monomial = FourierPolynomial::new(multi_bit_bsk.polynomial_size());
|
||||
|
||||
let fourier_multi_bit_ggsw_buffers = &fourier_multi_bit_ggsw_buffers;
|
||||
|
||||
for (task_idx, lwe_mask_elements, ggsw_bunch) in work_queue
|
||||
.iter()
|
||||
.skip(thread_idx)
|
||||
.step_by(thread_count.0)
|
||||
.copied()
|
||||
{
|
||||
let dest_idx = task_idx % ring_buffer_size;
|
||||
let (ready_for_consumer_lock, condvar, fourier_ggsw_buffer) =
|
||||
&fourier_multi_bit_ggsw_buffers[dest_idx];
|
||||
|
||||
let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();
|
||||
|
||||
// Wait while the buffer is not ready for processing and wait on the condvar to
|
||||
// get notified when we can start processing again
|
||||
while *ready_for_consumer {
|
||||
ready_for_consumer = condvar.wait(ready_for_consumer).unwrap();
|
||||
}
|
||||
|
||||
let mut fourier_ggsw_buffer = fourier_ggsw_buffer.lock().unwrap();
|
||||
|
||||
let mut bunch_iter = ggsw_bunch.iter();
|
||||
|
||||
// Keygen guarantees the first term is a constant term of the polynomial, no
|
||||
// polynomial multiplication required
|
||||
let ggsw_a_none = bunch_iter.next().unwrap();
|
||||
|
||||
fourier_ggsw_buffer
|
||||
.as_mut_view()
|
||||
.data()
|
||||
.copy_from_slice(ggsw_a_none.as_view().data());
|
||||
|
||||
let multi_bit_fourier_ggsw = fourier_ggsw_buffer.as_mut_view().data();
|
||||
|
||||
for (ggsw_idx, fourier_ggsw) in bunch_iter.enumerate() {
|
||||
// We already processed the first ggsw, advance the index by 1
|
||||
let ggsw_idx = ggsw_idx + 1;
|
||||
|
||||
let mut monomial_degree = Scalar::ZERO;
|
||||
for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() {
|
||||
let mask_position = lwe_mask_elements.len() - (mask_idx + 1);
|
||||
let selection_bit: Scalar =
|
||||
Scalar::cast_from((ggsw_idx >> mask_position) & 1);
|
||||
monomial_degree =
|
||||
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
|
||||
}
|
||||
|
||||
let switched_degree = pbs_modulus_switch(
|
||||
monomial_degree,
|
||||
lut_poly_size,
|
||||
ModulusSwitchOffset(0),
|
||||
LutCountLog(0),
|
||||
);
|
||||
|
||||
a_monomial
|
||||
.as_mut()
|
||||
.copy_from_slice(unit_polynomial.as_ref());
|
||||
polynomial_wrapping_monic_monomial_mul_assign(
|
||||
&mut a_monomial,
|
||||
MonomialDegree(switched_degree),
|
||||
);
|
||||
|
||||
fft.forward_as_integer(
|
||||
fourier_a_monomial.as_mut_view(),
|
||||
a_monomial.as_view(),
|
||||
buffers.stack(),
|
||||
);
|
||||
|
||||
update_with_fmadd(
|
||||
multi_bit_fourier_ggsw,
|
||||
fourier_ggsw.as_view().data(),
|
||||
fourier_a_monomial.as_view().data,
|
||||
false,
|
||||
lut_poly_size.to_fourier_polynomial_size().0,
|
||||
);
|
||||
}
|
||||
|
||||
// Drop the lock before we wake other threads
|
||||
drop(fourier_ggsw_buffer);
|
||||
|
||||
*ready_for_consumer = true;
|
||||
|
||||
// Wake threads waiting on the condvar
|
||||
condvar.notify_all();
|
||||
}
|
||||
};
|
||||
|
||||
let threads: Vec<_> = (0..thread_count.0)
|
||||
.map(|idx| s.spawn(move || produce_multi_bit_fourier_ggsw(idx)))
|
||||
.collect();
|
||||
|
||||
// We initialize ct0 for the successive external products
|
||||
let ct0 = lut;
|
||||
let mut ct1 = GlweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
ct0.glwe_size(),
|
||||
ct0.polynomial_size(),
|
||||
ct0.ciphertext_modulus(),
|
||||
);
|
||||
let ct1 = &mut ct1;
|
||||
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
let fft = Fft::new(multi_bit_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
|
||||
buffers.resize(
|
||||
add_external_product_assign_scratch::<Scalar>(
|
||||
multi_bit_bsk.glwe_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
let mut src_idx = 1usize;
|
||||
|
||||
for (ready_lock, condvar, multi_bit_fourier_ggsw) in fourier_multi_bit_ggsw_buffers
|
||||
.iter()
|
||||
.cycle()
|
||||
.take(multi_bit_bsk.multi_bit_input_lwe_dimension().0)
|
||||
{
|
||||
src_idx ^= 1;
|
||||
|
||||
let (src_ct, mut dst_ct) = if src_idx == 0 {
|
||||
(ct0.as_view(), ct1.as_mut_view())
|
||||
} else {
|
||||
(ct1.as_view(), ct0.as_mut_view())
|
||||
};
|
||||
|
||||
dst_ct.as_mut().fill(Scalar::ZERO);
|
||||
|
||||
let mut ready = ready_lock.lock().unwrap();
|
||||
|
||||
while !*ready {
|
||||
ready = condvar.wait(ready).unwrap();
|
||||
}
|
||||
|
||||
let multi_bit_fourier_ggsw = multi_bit_fourier_ggsw.lock().unwrap();
|
||||
|
||||
add_external_product_assign(
|
||||
dst_ct,
|
||||
multi_bit_fourier_ggsw.as_view(),
|
||||
src_ct,
|
||||
fft,
|
||||
buffers.stack(),
|
||||
);
|
||||
|
||||
*ready = false;
|
||||
|
||||
// Wake a single producer thread sleeping on the condvar (only one will get to work
|
||||
// anyways)
|
||||
condvar.notify_one();
|
||||
}
|
||||
|
||||
if src_idx == 0 {
|
||||
ct0.as_mut().copy_from_slice(ct1.as_ref());
|
||||
}
|
||||
|
||||
threads.into_iter().for_each(|t| t.join().unwrap());
|
||||
});
|
||||
}
|
||||
|
||||
/// Perform a programmable bootstrap with given an input [`LWE ciphertext`](`LweCiphertext`), a
|
||||
/// look-up table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE multi-bit bootstrap
|
||||
/// key`](`LweMultiBitBootstrapKey`) in the fourier domain. The result is written in the provided
|
||||
@@ -743,7 +994,7 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
|
||||
thread_count: ThreadCount,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync + Send,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
AccCont: Container<Element = Scalar>,
|
||||
@@ -811,7 +1062,12 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
|
||||
.as_mut()
|
||||
.copy_from_slice(accumulator.as_ref());
|
||||
|
||||
multi_bit_blind_rotate_assign(input, &mut local_accumulator, multi_bit_bsk, thread_count);
|
||||
multi_bit_blind_rotate_assign_deterministic(
|
||||
input,
|
||||
&mut local_accumulator,
|
||||
multi_bit_bsk,
|
||||
thread_count,
|
||||
);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user