mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
refactor(fft): update fft code to use FourierPolynomialSize
This commit is contained in:
@@ -111,6 +111,28 @@ impl PolynomialSize {
|
||||
pub fn log2(&self) -> PolynomialSizeLog {
|
||||
PolynomialSizeLog((self.0 as f64).log2().ceil() as usize)
|
||||
}
|
||||
|
||||
pub fn to_fourier_polynomial_size(&self) -> FourierPolynomialSize {
|
||||
assert_eq!(
|
||||
self.0 % 2,
|
||||
0,
|
||||
"Cannot convert a PolynomialSize that is not a multiple of 2 to FourierPolynomialSize"
|
||||
);
|
||||
FourierPolynomialSize(self.0 / 2)
|
||||
}
|
||||
}
|
||||
|
||||
/// The number of elements in the container of a fourier polynomial.
|
||||
///
|
||||
/// Assuming a standard polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this returns
|
||||
/// $\frac{N}{2}$.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct FourierPolynomialSize(pub usize);
|
||||
|
||||
impl FourierPolynomialSize {
|
||||
pub fn to_standard_polynomial_size(&self) -> PolynomialSize {
|
||||
PolynomialSize(self.0 * 2)
|
||||
}
|
||||
}
|
||||
|
||||
/// The logarithm of the number of coefficients of a polynomial.
|
||||
|
||||
@@ -39,10 +39,10 @@ impl<C: Container<Element = c64>> FourierLweBootstrapKey<C> {
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> Self {
|
||||
assert_eq!(polynomial_size.0 % 2, 0);
|
||||
assert_eq!(
|
||||
data.container_len(),
|
||||
input_lwe_dimension.0 * polynomial_size.0 / 2
|
||||
input_lwe_dimension.0
|
||||
* polynomial_size.to_fourier_polynomial_size().0
|
||||
* decomposition_level_count.0
|
||||
* glwe_size.0
|
||||
* glwe_size.0
|
||||
@@ -148,12 +148,11 @@ impl FourierLweBootstrapKey<ABox<[c64]>> {
|
||||
) -> FourierLweBootstrapKey<ABox<[c64]>> {
|
||||
let boxed = avec![
|
||||
c64::default();
|
||||
polynomial_size.0
|
||||
polynomial_size.to_fourier_polynomial_size().0
|
||||
* input_lwe_dimension.0
|
||||
* decomposition_level_count.0
|
||||
* glwe_size.0
|
||||
* glwe_size.0
|
||||
/ 2
|
||||
]
|
||||
.into_boxed_slice();
|
||||
|
||||
|
||||
@@ -67,10 +67,12 @@ impl<C: Container<Element = c64>> FourierGgswCiphertext<C> {
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> Self {
|
||||
assert_eq!(polynomial_size.0 % 2, 0);
|
||||
assert_eq!(
|
||||
data.container_len(),
|
||||
polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0
|
||||
polynomial_size.to_fourier_polynomial_size().0
|
||||
* glwe_size.0
|
||||
* glwe_size.0
|
||||
* decomposition_level_count.0
|
||||
);
|
||||
|
||||
Self {
|
||||
@@ -143,10 +145,9 @@ impl<C: Container<Element = c64>> FourierGgswLevelMatrix<C> {
|
||||
row_count: usize,
|
||||
decomposition_level: DecompositionLevel,
|
||||
) -> Self {
|
||||
assert_eq!(polynomial_size.0 % 2, 0);
|
||||
assert_eq!(
|
||||
data.container_len(),
|
||||
polynomial_size.0 / 2 * glwe_size.0 * row_count
|
||||
polynomial_size.to_fourier_polynomial_size().0 * glwe_size.0 * row_count
|
||||
);
|
||||
Self {
|
||||
data,
|
||||
@@ -200,8 +201,10 @@ impl<C: Container<Element = c64>> FourierGgswLevelRow<C> {
|
||||
polynomial_size: PolynomialSize,
|
||||
decomposition_level: DecompositionLevel,
|
||||
) -> Self {
|
||||
assert_eq!(polynomial_size.0 % 2, 0);
|
||||
assert_eq!(data.container_len(), polynomial_size.0 / 2 * glwe_size.0);
|
||||
assert_eq!(
|
||||
data.container_len(),
|
||||
polynomial_size.to_fourier_polynomial_size().0 * glwe_size.0
|
||||
);
|
||||
Self {
|
||||
data,
|
||||
polynomial_size,
|
||||
@@ -261,10 +264,10 @@ impl<'a> FourierGgswCiphertextMutView<'a> {
|
||||
mut stack: DynStack<'_>,
|
||||
) {
|
||||
debug_assert_eq!(coef_ggsw.polynomial_size(), self.polynomial_size());
|
||||
let poly_size = coef_ggsw.polynomial_size().0;
|
||||
let fourier_poly_size = coef_ggsw.polynomial_size().to_fourier_polynomial_size().0;
|
||||
|
||||
for (fourier_poly, coef_poly) in izip!(
|
||||
self.data().into_chunks(poly_size / 2),
|
||||
self.data().into_chunks(fourier_poly_size),
|
||||
coef_ggsw.as_polynomial_list().iter()
|
||||
) {
|
||||
// SAFETY: forward_as_torus doesn't write any uninitialized values into its output
|
||||
@@ -291,7 +294,10 @@ impl FourierGgswCiphertext<ABox<[c64]>> {
|
||||
) -> FourierGgswCiphertext<ABox<[c64]>> {
|
||||
let boxed = avec![
|
||||
c64::default();
|
||||
polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0
|
||||
polynomial_size.to_fourier_polynomial_size().0
|
||||
* glwe_size.0
|
||||
* glwe_size.0
|
||||
* decomposition_level_count.0
|
||||
]
|
||||
.into_boxed_slice();
|
||||
|
||||
@@ -314,9 +320,10 @@ pub fn add_external_product_assign_scratch<Scalar>(
|
||||
let align = CACHELINE_ALIGN;
|
||||
let standard_scratch =
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let fourier_polynomial_size = polynomial_size.to_fourier_polynomial_size().0;
|
||||
let fourier_scratch =
|
||||
StackReq::try_new_aligned::<c64>(glwe_size.0 * polynomial_size.0 / 2, align)?;
|
||||
let fourier_scratch_single = StackReq::try_new_aligned::<c64>(polynomial_size.0 / 2, align)?;
|
||||
StackReq::try_new_aligned::<c64>(glwe_size.0 * fourier_polynomial_size, align)?;
|
||||
let fourier_scratch_single = StackReq::try_new_aligned::<c64>(fourier_polynomial_size, align)?;
|
||||
|
||||
let substack3 = fft.forward_scratch()?;
|
||||
let substack2 = substack3.try_and(fourier_scratch_single)?;
|
||||
@@ -348,7 +355,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
|
||||
debug_assert_eq!(ggsw.glwe_size(), out.glwe_size());
|
||||
|
||||
let align = CACHELINE_ALIGN;
|
||||
let poly_size = ggsw.polynomial_size().0;
|
||||
let fourier_poly_size = ggsw.polynomial_size().to_fourier_polynomial_size().0;
|
||||
|
||||
// we round the input mask and body
|
||||
let decomposer = SignedDecomposer::<Scalar>::new(
|
||||
@@ -357,7 +364,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
|
||||
);
|
||||
|
||||
let (mut output_fft_buffer, mut substack0) =
|
||||
stack.make_aligned_uninit::<c64>(poly_size / 2 * ggsw.glwe_size().0, align);
|
||||
stack.make_aligned_uninit::<c64>(fourier_poly_size * ggsw.glwe_size().0, align);
|
||||
// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
|
||||
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
|
||||
// it has been fully initialized for the first time.
|
||||
@@ -405,7 +412,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
|
||||
.for_each(|(ggsw_row, glwe_poly)| {
|
||||
let (mut fourier, substack3) = substack2
|
||||
.rb_mut()
|
||||
.make_aligned_uninit::<c64>(poly_size / 2, align);
|
||||
.make_aligned_uninit::<c64>(fourier_poly_size, align);
|
||||
// We perform the forward fft transform for the glwe polynomial
|
||||
let fourier = fft
|
||||
.forward_as_integer(
|
||||
@@ -424,7 +431,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
|
||||
ggsw_row,
|
||||
fourier,
|
||||
is_output_uninit,
|
||||
poly_size,
|
||||
fourier_poly_size,
|
||||
)
|
||||
};
|
||||
|
||||
@@ -445,7 +452,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
|
||||
izip!(
|
||||
out.as_mut_polynomial_list().iter_mut(),
|
||||
output_fft_buffer
|
||||
.into_chunks(poly_size / 2)
|
||||
.into_chunks(fourier_poly_size)
|
||||
.map(|slice| FourierPolynomialView { data: slice }),
|
||||
)
|
||||
.for_each(|(out, fourier)| {
|
||||
@@ -635,7 +642,7 @@ unsafe fn update_with_fmadd(
|
||||
ggsw_row: FourierGgswLevelRowView,
|
||||
fourier: &[c64],
|
||||
is_output_uninit: bool,
|
||||
poly_size: usize,
|
||||
fourier_poly_size: usize,
|
||||
) {
|
||||
#[allow(clippy::type_complexity)]
|
||||
let ptr_fn = || -> unsafe fn(&mut [MaybeUninit<c64>], &[c64], &[c64], bool) {
|
||||
@@ -657,8 +664,8 @@ unsafe fn update_with_fmadd(
|
||||
let ptr = ptr_fn();
|
||||
|
||||
izip!(
|
||||
output_fft_buffer.into_chunks(poly_size / 2),
|
||||
ggsw_row.data.into_chunks(poly_size / 2)
|
||||
output_fft_buffer.into_chunks(fourier_poly_size),
|
||||
ggsw_row.data.into_chunks(fourier_poly_size)
|
||||
)
|
||||
.for_each(|(output_fourier, ggsw_poly)| {
|
||||
ptr(output_fourier, ggsw_poly, fourier, is_output_uninit);
|
||||
|
||||
@@ -490,10 +490,13 @@ impl<C: Container<Element = c64>> FourierGgswCiphertextList<C> {
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> Self {
|
||||
assert_eq!(polynomial_size.0 % 2, 0);
|
||||
assert_eq!(
|
||||
data.container_len(),
|
||||
count * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0
|
||||
count
|
||||
* polynomial_size.to_fourier_polynomial_size().0
|
||||
* glwe_size.0
|
||||
* glwe_size.0
|
||||
* decomposition_level_count.0
|
||||
);
|
||||
|
||||
Self {
|
||||
@@ -588,7 +591,10 @@ impl<C: Container<Element = c64>> FourierGgswCiphertextList<C> {
|
||||
let decomposition_base_log = self.decomposition_base_log;
|
||||
|
||||
let (left, right) = self.fourier.data.split_at(
|
||||
mid * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0,
|
||||
mid * polynomial_size.to_fourier_polynomial_size().0
|
||||
* glwe_size.0
|
||||
* glwe_size.0
|
||||
* decomposition_level_count.0,
|
||||
);
|
||||
(
|
||||
Self::new(
|
||||
|
||||
@@ -434,7 +434,11 @@ pub fn test_cmux_tree() {
|
||||
let mut ggsw_list = FourierGgswCiphertextList::new(
|
||||
vec![
|
||||
c64::default();
|
||||
nb_ggsw * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * level.0
|
||||
nb_ggsw
|
||||
* polynomial_size.to_fourier_polynomial_size().0
|
||||
* glwe_size.0
|
||||
* glwe_size.0
|
||||
* level.0
|
||||
],
|
||||
nb_ggsw,
|
||||
glwe_size,
|
||||
|
||||
@@ -548,7 +548,7 @@ impl<C: Container<Element = c64>> serde::Serialize for FourierPolynomialList<C>
|
||||
let chunk_count = if polynomial_size.0 == 0 {
|
||||
0
|
||||
} else {
|
||||
data.len() / (polynomial_size.0 / 2)
|
||||
data.len() / (polynomial_size.to_fourier_polynomial_size().0)
|
||||
};
|
||||
|
||||
let mut state = serializer.serialize_seq(Some(2 + chunk_count))?;
|
||||
@@ -618,8 +618,10 @@ impl<'de, C: IntoContainerOwned<Element = c64>> serde::Deserialize<'de>
|
||||
}
|
||||
}
|
||||
|
||||
let mut data =
|
||||
C::collect((0..(polynomial_size.0 / 2 * chunk_count)).map(|_| c64::default()));
|
||||
let mut data = C::collect(
|
||||
(0..(polynomial_size.to_fourier_polynomial_size().0 * chunk_count))
|
||||
.map(|_| c64::default()),
|
||||
);
|
||||
|
||||
if chunk_count != 0 {
|
||||
let fft = Fft::new(polynomial_size);
|
||||
|
||||
Reference in New Issue
Block a user