refactor(fft): update fft code to use FourierPolynomialSize

This commit is contained in:
Arthur Meyre
2023-02-17 14:33:26 +01:00
parent d3b3c5ab21
commit bf6f699e8c
6 changed files with 70 additions and 30 deletions

View File

@@ -111,6 +111,28 @@ impl PolynomialSize {
pub fn log2(&self) -> PolynomialSizeLog {
    // Take the base-2 logarithm in floating point, then round up so that a size
    // which is not an exact power of two maps to the next integer exponent.
    let coefficient_count = self.0 as f64;
    let rounded_up_exponent = coefficient_count.log2().ceil();
    PolynomialSizeLog(rounded_up_exponent as usize)
}
/// Convert this size to the matching [`FourierPolynomialSize`], i.e. half of the wrapped
/// value.
///
/// # Panics
///
/// Panics if the wrapped value is not a multiple of 2.
pub fn to_fourier_polynomial_size(&self) -> FourierPolynomialSize {
    let coefficient_count = self.0;
    assert_eq!(
        coefficient_count % 2,
        0,
        "Cannot convert a PolynomialSize that is not a multiple of 2 to FourierPolynomialSize"
    );
    FourierPolynomialSize(coefficient_count / 2)
}
}
/// The number of elements in the container of a fourier polynomial.
///
/// Assuming a standard polynomial $a\_0 + a\_1X + \dots + a\_{N-1}X^{N-1}$, this is
/// $\frac{N}{2}$.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct FourierPolynomialSize(pub usize);

impl FourierPolynomialSize {
    /// Convert back to the corresponding [`PolynomialSize`], i.e. twice the wrapped value.
    pub fn to_standard_polynomial_size(&self) -> PolynomialSize {
        let Self(fourier_element_count) = *self;
        PolynomialSize(fourier_element_count * 2)
    }
}
/// The logarithm of the number of coefficients of a polynomial.

View File

@@ -39,10 +39,10 @@ impl<C: Container<Element = c64>> FourierLweBootstrapKey<C> {
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
) -> Self {
assert_eq!(polynomial_size.0 % 2, 0);
assert_eq!(
data.container_len(),
input_lwe_dimension.0 * polynomial_size.0 / 2
input_lwe_dimension.0
* polynomial_size.to_fourier_polynomial_size().0
* decomposition_level_count.0
* glwe_size.0
* glwe_size.0
@@ -148,12 +148,11 @@ impl FourierLweBootstrapKey<ABox<[c64]>> {
) -> FourierLweBootstrapKey<ABox<[c64]>> {
let boxed = avec![
c64::default();
polynomial_size.0
polynomial_size.to_fourier_polynomial_size().0
* input_lwe_dimension.0
* decomposition_level_count.0
* glwe_size.0
* glwe_size.0
/ 2
]
.into_boxed_slice();

View File

@@ -67,10 +67,12 @@ impl<C: Container<Element = c64>> FourierGgswCiphertext<C> {
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
) -> Self {
assert_eq!(polynomial_size.0 % 2, 0);
assert_eq!(
data.container_len(),
polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0
polynomial_size.to_fourier_polynomial_size().0
* glwe_size.0
* glwe_size.0
* decomposition_level_count.0
);
Self {
@@ -143,10 +145,9 @@ impl<C: Container<Element = c64>> FourierGgswLevelMatrix<C> {
row_count: usize,
decomposition_level: DecompositionLevel,
) -> Self {
assert_eq!(polynomial_size.0 % 2, 0);
assert_eq!(
data.container_len(),
polynomial_size.0 / 2 * glwe_size.0 * row_count
polynomial_size.to_fourier_polynomial_size().0 * glwe_size.0 * row_count
);
Self {
data,
@@ -200,8 +201,10 @@ impl<C: Container<Element = c64>> FourierGgswLevelRow<C> {
polynomial_size: PolynomialSize,
decomposition_level: DecompositionLevel,
) -> Self {
assert_eq!(polynomial_size.0 % 2, 0);
assert_eq!(data.container_len(), polynomial_size.0 / 2 * glwe_size.0);
assert_eq!(
data.container_len(),
polynomial_size.to_fourier_polynomial_size().0 * glwe_size.0
);
Self {
data,
polynomial_size,
@@ -261,10 +264,10 @@ impl<'a> FourierGgswCiphertextMutView<'a> {
mut stack: DynStack<'_>,
) {
debug_assert_eq!(coef_ggsw.polynomial_size(), self.polynomial_size());
let poly_size = coef_ggsw.polynomial_size().0;
let fourier_poly_size = coef_ggsw.polynomial_size().to_fourier_polynomial_size().0;
for (fourier_poly, coef_poly) in izip!(
self.data().into_chunks(poly_size / 2),
self.data().into_chunks(fourier_poly_size),
coef_ggsw.as_polynomial_list().iter()
) {
// SAFETY: forward_as_torus doesn't write any uninitialized values into its output
@@ -291,7 +294,10 @@ impl FourierGgswCiphertext<ABox<[c64]>> {
) -> FourierGgswCiphertext<ABox<[c64]>> {
let boxed = avec![
c64::default();
polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0
polynomial_size.to_fourier_polynomial_size().0
* glwe_size.0
* glwe_size.0
* decomposition_level_count.0
]
.into_boxed_slice();
@@ -314,9 +320,10 @@ pub fn add_external_product_assign_scratch<Scalar>(
let align = CACHELINE_ALIGN;
let standard_scratch =
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
let fourier_polynomial_size = polynomial_size.to_fourier_polynomial_size().0;
let fourier_scratch =
StackReq::try_new_aligned::<c64>(glwe_size.0 * polynomial_size.0 / 2, align)?;
let fourier_scratch_single = StackReq::try_new_aligned::<c64>(polynomial_size.0 / 2, align)?;
StackReq::try_new_aligned::<c64>(glwe_size.0 * fourier_polynomial_size, align)?;
let fourier_scratch_single = StackReq::try_new_aligned::<c64>(fourier_polynomial_size, align)?;
let substack3 = fft.forward_scratch()?;
let substack2 = substack3.try_and(fourier_scratch_single)?;
@@ -348,7 +355,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
debug_assert_eq!(ggsw.glwe_size(), out.glwe_size());
let align = CACHELINE_ALIGN;
let poly_size = ggsw.polynomial_size().0;
let fourier_poly_size = ggsw.polynomial_size().to_fourier_polynomial_size().0;
// we round the input mask and body
let decomposer = SignedDecomposer::<Scalar>::new(
@@ -357,7 +364,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
);
let (mut output_fft_buffer, mut substack0) =
stack.make_aligned_uninit::<c64>(poly_size / 2 * ggsw.glwe_size().0, align);
stack.make_aligned_uninit::<c64>(fourier_poly_size * ggsw.glwe_size().0, align);
// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
// it has been fully initialized for the first time.
@@ -405,7 +412,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
.for_each(|(ggsw_row, glwe_poly)| {
let (mut fourier, substack3) = substack2
.rb_mut()
.make_aligned_uninit::<c64>(poly_size / 2, align);
.make_aligned_uninit::<c64>(fourier_poly_size, align);
// We perform the forward fft transform for the glwe polynomial
let fourier = fft
.forward_as_integer(
@@ -424,7 +431,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
ggsw_row,
fourier,
is_output_uninit,
poly_size,
fourier_poly_size,
)
};
@@ -445,7 +452,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
izip!(
out.as_mut_polynomial_list().iter_mut(),
output_fft_buffer
.into_chunks(poly_size / 2)
.into_chunks(fourier_poly_size)
.map(|slice| FourierPolynomialView { data: slice }),
)
.for_each(|(out, fourier)| {
@@ -635,7 +642,7 @@ unsafe fn update_with_fmadd(
ggsw_row: FourierGgswLevelRowView,
fourier: &[c64],
is_output_uninit: bool,
poly_size: usize,
fourier_poly_size: usize,
) {
#[allow(clippy::type_complexity)]
let ptr_fn = || -> unsafe fn(&mut [MaybeUninit<c64>], &[c64], &[c64], bool) {
@@ -657,8 +664,8 @@ unsafe fn update_with_fmadd(
let ptr = ptr_fn();
izip!(
output_fft_buffer.into_chunks(poly_size / 2),
ggsw_row.data.into_chunks(poly_size / 2)
output_fft_buffer.into_chunks(fourier_poly_size),
ggsw_row.data.into_chunks(fourier_poly_size)
)
.for_each(|(output_fourier, ggsw_poly)| {
ptr(output_fourier, ggsw_poly, fourier, is_output_uninit);

View File

@@ -490,10 +490,13 @@ impl<C: Container<Element = c64>> FourierGgswCiphertextList<C> {
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
) -> Self {
assert_eq!(polynomial_size.0 % 2, 0);
assert_eq!(
data.container_len(),
count * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0
count
* polynomial_size.to_fourier_polynomial_size().0
* glwe_size.0
* glwe_size.0
* decomposition_level_count.0
);
Self {
@@ -588,7 +591,10 @@ impl<C: Container<Element = c64>> FourierGgswCiphertextList<C> {
let decomposition_base_log = self.decomposition_base_log;
let (left, right) = self.fourier.data.split_at(
mid * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0,
mid * polynomial_size.to_fourier_polynomial_size().0
* glwe_size.0
* glwe_size.0
* decomposition_level_count.0,
);
(
Self::new(

View File

@@ -434,7 +434,11 @@ pub fn test_cmux_tree() {
let mut ggsw_list = FourierGgswCiphertextList::new(
vec![
c64::default();
nb_ggsw * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * level.0
nb_ggsw
* polynomial_size.to_fourier_polynomial_size().0
* glwe_size.0
* glwe_size.0
* level.0
],
nb_ggsw,
glwe_size,

View File

@@ -548,7 +548,7 @@ impl<C: Container<Element = c64>> serde::Serialize for FourierPolynomialList<C>
let chunk_count = if polynomial_size.0 == 0 {
0
} else {
data.len() / (polynomial_size.0 / 2)
data.len() / (polynomial_size.to_fourier_polynomial_size().0)
};
let mut state = serializer.serialize_seq(Some(2 + chunk_count))?;
@@ -618,8 +618,10 @@ impl<'de, C: IntoContainerOwned<Element = c64>> serde::Deserialize<'de>
}
}
let mut data =
C::collect((0..(polynomial_size.0 / 2 * chunk_count)).map(|_| c64::default()));
let mut data = C::collect(
(0..(polynomial_size.to_fourier_polynomial_size().0 * chunk_count))
.map(|_| c64::default()),
);
if chunk_count != 0 {
let fft = Fft::new(polynomial_size);