mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 15:18:33 -05:00
feat(core): add add_external_product_assign
This commit is contained in:
@@ -6,7 +6,11 @@ use crate::core_crypto::commons::parameters::*;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::crypto::bootstrap::{bootstrap_scratch, FourierLweBootstrapKey};
|
||||
use crate::core_crypto::fft_impl::crypto::ggsw::{cmux, cmux_scratch, FourierGgswCiphertext};
|
||||
use crate::core_crypto::fft_impl::crypto::ggsw::{
|
||||
add_external_product_assign as impl_add_external_product_assign,
|
||||
add_external_product_assign_scratch as impl_add_external_product_assign_scratch, cmux,
|
||||
cmux_scratch, FourierGgswCiphertext,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::crypto::wop_pbs::blind_rotate_assign_scratch;
|
||||
use crate::core_crypto::fft_impl::math::fft::{Fft, FftView};
|
||||
use concrete_fft::c64;
|
||||
@@ -79,6 +83,210 @@ pub fn blind_rotate_assign_mem_optimized_requirement<Scalar>(
|
||||
blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
/// Compute the external product of `ggsw` and `glwe`, and add the result to `out`.
|
||||
///
|
||||
/// Strictly speaking this function computes:
|
||||
///
|
||||
/// ```text
|
||||
/// out <- out + glwe * ggsw
|
||||
/// ```
|
||||
///
|
||||
/// If you want to manage the computation memory manually you can use
|
||||
/// [`add_external_product_assign_mem_optimized`].
|
||||
pub fn add_external_product_assign<Scalar, OutputGlweCont, InputGlweCont, GgswCont>(
|
||||
out: &mut GlweCiphertext<OutputGlweCont>,
|
||||
ggsw: &FourierGgswCiphertext<GgswCont>,
|
||||
glwe: &GlweCiphertext<InputGlweCont>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
OutputGlweCont: ContainerMut<Element = Scalar>,
|
||||
GgswCont: Container<Element = c64>,
|
||||
InputGlweCont: Container<Element = Scalar>,
|
||||
{
|
||||
let fft = Fft::new(ggsw.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
buffers.resize(
|
||||
add_external_product_assign_mem_optimized_requirement::<Scalar>(
|
||||
ggsw.glwe_size(),
|
||||
ggsw.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
add_external_product_assign_mem_optimized(out, ggsw, glwe, fft, buffers.stack());
|
||||
}
|
||||
|
||||
/// Memory optimized version of [`add_external_product_assign`], the caller must provide a properly
|
||||
/// configured [`FftView`] object and a `DynStack` used as a memory buffer having a capacity at
|
||||
/// least as large as the result of [`add_external_product_assign_mem_optimized_requirement`].
|
||||
///
|
||||
/// Compute the external product of `ggsw` and `glwe`, and add the result to `out`.
|
||||
///
|
||||
/// Strictly speaking this function computes:
|
||||
///
|
||||
/// ```text
|
||||
/// out <- out + glwe * ggsw
|
||||
/// ```
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use tfhe::core_crypto::prelude::*;
|
||||
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
/// // computations
|
||||
/// // Define parameters for GgswCiphertext creation
|
||||
/// let glwe_size = GlweSize(2);
|
||||
/// let polynomial_size = PolynomialSize(2048);
|
||||
/// let decomp_base_log = DecompositionBaseLog(23);
|
||||
/// let decomp_level_count = DecompositionLevelCount(1);
|
||||
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
|
||||
///
|
||||
/// // Create the PRNG
|
||||
/// let mut seeder = new_seeder();
|
||||
/// let seeder = seeder.as_mut();
|
||||
/// let mut encryption_generator =
|
||||
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
/// let mut secret_generator =
|
||||
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
///
|
||||
/// // Create the GlweSecretKey
|
||||
/// let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
/// glwe_size.to_glwe_dimension(),
|
||||
/// polynomial_size,
|
||||
/// &mut secret_generator,
|
||||
/// );
|
||||
///
|
||||
/// // Create the plaintext, here we will multiply by 3
|
||||
/// let msg_ggsw = Plaintext(3u64);
|
||||
///
|
||||
/// // Create a new GgswCiphertext
|
||||
/// let mut ggsw = GgswCiphertext::new(
|
||||
/// 0u64,
|
||||
/// glwe_size,
|
||||
/// polynomial_size,
|
||||
/// decomp_base_log,
|
||||
/// decomp_level_count,
|
||||
/// );
|
||||
///
|
||||
/// encrypt_ggsw_ciphertext(
|
||||
/// &glwe_secret_key,
|
||||
/// &mut ggsw,
|
||||
/// msg_ggsw,
|
||||
/// glwe_modular_std_dev,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let ct_plaintext = Plaintext(3 << 60);
|
||||
///
|
||||
/// let ct_plaintexts = PlaintextList::new(ct_plaintext.0, PlaintextCount(polynomial_size.0));
|
||||
///
|
||||
/// let mut ct = GlweCiphertext::new(0u64, glwe_size, polynomial_size);
|
||||
///
|
||||
/// encrypt_glwe_ciphertext(
|
||||
/// &glwe_secret_key,
|
||||
/// &mut ct,
|
||||
/// &ct_plaintexts,
|
||||
/// glwe_modular_std_dev,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let fft = Fft::new(polynomial_size);
|
||||
/// let fft = fft.as_view();
|
||||
/// let mut buffers = ComputationBuffers::new();
|
||||
///
|
||||
/// let buffer_size_req = add_external_product_assign_mem_optimized_requirement::<u64>(
|
||||
/// glwe_size,
|
||||
/// polynomial_size,
|
||||
/// fft,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required();
|
||||
///
|
||||
/// let buffer_size_req = buffer_size_req.max(
|
||||
/// convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required(),
|
||||
/// );
|
||||
///
|
||||
/// buffers.resize(buffer_size_req);
|
||||
///
|
||||
/// let mut fourier_ggsw = FourierGgswCiphertext::new(
|
||||
/// glwe_size,
|
||||
/// polynomial_size,
|
||||
/// decomp_base_log,
|
||||
/// decomp_level_count,
|
||||
/// );
|
||||
///
|
||||
/// convert_standard_ggsw_ciphertext_to_fourier_mem_optimized(
|
||||
/// &ggsw,
|
||||
/// &mut fourier_ggsw,
|
||||
/// fft,
|
||||
/// buffers.stack(),
|
||||
/// );
|
||||
///
|
||||
/// let mut ct_out = ct.clone();
|
||||
///
|
||||
/// add_external_product_assign_mem_optimized(
|
||||
/// &mut ct_out,
|
||||
/// &fourier_ggsw,
|
||||
/// &ct,
|
||||
/// fft,
|
||||
/// buffers.stack(),
|
||||
/// );
|
||||
///
|
||||
/// let mut output_plaintext_list = PlaintextList::new(0u64, ct_plaintexts.plaintext_count());
|
||||
///
|
||||
/// decrypt_glwe_ciphertext(&glwe_secret_key, &ct_out, &mut output_plaintext_list);
|
||||
///
|
||||
/// let signed_decomposer =
|
||||
/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
///
|
||||
/// output_plaintext_list
|
||||
/// .iter_mut()
|
||||
/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
|
||||
///
|
||||
/// // As we cloned the input ciphertext for the output, the external product result is added to the
|
||||
/// // originally contained value, hence why we expect ct_plaintext + ct_plaintext * msg_ggsw
|
||||
/// assert!(output_plaintext_list
|
||||
/// .iter()
|
||||
/// .all(|x| *x.0 == ct_plaintext.0 + ct_plaintext.0 * msg_ggsw.0));
|
||||
/// ```
|
||||
pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGlweCont, GgswCont>(
|
||||
out: &mut GlweCiphertext<OutputGlweCont>,
|
||||
ggsw: &FourierGgswCiphertext<GgswCont>,
|
||||
glwe: &GlweCiphertext<InputGlweCont>,
|
||||
fft: FftView<'_>,
|
||||
stack: DynStack<'_>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
OutputGlweCont: ContainerMut<Element = Scalar>,
|
||||
GgswCont: Container<Element = c64>,
|
||||
InputGlweCont: Container<Element = Scalar>,
|
||||
{
|
||||
impl_add_external_product_assign(
|
||||
out.as_mut_view(),
|
||||
ggsw.as_view(),
|
||||
glwe.as_view(),
|
||||
fft,
|
||||
stack,
|
||||
)
|
||||
}
|
||||
|
||||
/// Return the required memory for [`add_external_product_assign_mem_optimized`].
|
||||
pub fn add_external_product_assign_mem_optimized_requirement<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
impl_add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
/// Compute a cmux on the input `ct0` and `ct1` using `ggsw` as selector.
|
||||
///
|
||||
/// `ct0` and `ct1` are both modified by this operation, the result is stored in `ct0` at the end
|
||||
@@ -94,8 +302,9 @@ pub fn blind_rotate_assign_mem_optimized_requirement<Scalar>(
|
||||
/// Therefore encrypting values other than 0 or 1 in the `ggsw` will yield a linear combination of
|
||||
/// `ct0` and `ct1`
|
||||
///
|
||||
/// From a logical point of view (without considering the side effects of the implementationw) the
|
||||
/// cmux operation does the following assuming a binary (0 or 1) value stored in the input `ggsw`:
|
||||
/// From a logical point of view (without considering the side effects of the implementation) the
|
||||
/// cmux operation does the following assuming a binary (0 or 1) value encrypted in the input
|
||||
/// `ggsw`:
|
||||
///
|
||||
/// ```text
|
||||
/// def cmux(ct0, ct1, ggsw):
|
||||
@@ -336,6 +545,7 @@ pub fn cmux_assign_mem_optimized<Scalar, Cont0, Cont1, GgswCont>(
|
||||
);
|
||||
}
|
||||
|
||||
/// Return the required memory for [`cmux_assign_mem_optimized`].
|
||||
pub fn cmux_assign_mem_optimized_requirement<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
|
||||
Reference in New Issue
Block a user