mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
feat(core): add variable Scalar type to PBS for input and output
This commit is contained in:
@@ -177,18 +177,26 @@ use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
/// "Multiplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}"
|
||||
/// );
|
||||
/// ```
|
||||
pub fn blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
pub fn blind_rotate_assign<InputScalar, OutputScalar, InputCont, OutputCont, KeyCont>(
|
||||
input: &LweCiphertext<InputCont>,
|
||||
lut: &mut GlweCiphertext<OutputCont>,
|
||||
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputScalar: UnsignedTorus + CastInto<usize>,
|
||||
OutputScalar: UnsignedTorus,
|
||||
InputCont: Container<Element = InputScalar>,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = c64>,
|
||||
{
|
||||
assert_eq!(input.ciphertext_modulus(), lut.ciphertext_modulus());
|
||||
assert!(
|
||||
input.ciphertext_modulus().is_power_of_two(),
|
||||
"This operation requires the input to have a power of two modulus."
|
||||
);
|
||||
assert!(
|
||||
lut.ciphertext_modulus().is_power_of_two(),
|
||||
"This operation requires the lut to have a power of two modulus."
|
||||
);
|
||||
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
@@ -196,7 +204,7 @@ pub fn blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
let fft = fft.as_view();
|
||||
|
||||
buffers.resize(
|
||||
blind_rotate_assign_mem_optimized_requirement::<Scalar>(
|
||||
blind_rotate_assign_mem_optimized_requirement::<OutputScalar>(
|
||||
fourier_bsk.glwe_size(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
@@ -213,7 +221,13 @@ pub fn blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
/// Memory optimized version of [`blind_rotate_assign`], the caller must provide
|
||||
/// a properly configured [`FftView`] object and a `PodStack` used as a memory buffer having a
|
||||
/// capacity at least as large as the result of [`blind_rotate_assign_mem_optimized_requirement`].
|
||||
pub fn blind_rotate_assign_mem_optimized<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
pub fn blind_rotate_assign_mem_optimized<
|
||||
InputScalar,
|
||||
OutputScalar,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
KeyCont,
|
||||
>(
|
||||
input: &LweCiphertext<InputCont>,
|
||||
lut: &mut GlweCiphertext<OutputCont>,
|
||||
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
|
||||
@@ -221,12 +235,20 @@ pub fn blind_rotate_assign_mem_optimized<Scalar, InputCont, OutputCont, KeyCont>
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputScalar: UnsignedTorus + CastInto<usize>,
|
||||
OutputScalar: UnsignedTorus,
|
||||
InputCont: Container<Element = InputScalar>,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = c64>,
|
||||
{
|
||||
assert_eq!(input.ciphertext_modulus(), lut.ciphertext_modulus());
|
||||
assert!(
|
||||
input.ciphertext_modulus().is_power_of_two(),
|
||||
"This operation requires the input to have a power of two modulus."
|
||||
);
|
||||
assert!(
|
||||
lut.ciphertext_modulus().is_power_of_two(),
|
||||
"This operation requires the lut to have a power of two modulus."
|
||||
);
|
||||
|
||||
// Blind rotate assign manages the rounding to go back to the proper torus if the ciphertext
|
||||
// modulus is not the native one
|
||||
@@ -236,12 +258,12 @@ pub fn blind_rotate_assign_mem_optimized<Scalar, InputCont, OutputCont, KeyCont>
|
||||
}
|
||||
|
||||
/// Return the required memory for [`blind_rotate_assign_mem_optimized`].
|
||||
pub fn blind_rotate_assign_mem_optimized_requirement<Scalar>(
|
||||
pub fn blind_rotate_assign_mem_optimized_requirement<OutputScalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
blind_rotate_assign_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
/// Compute the external product of `ggsw` and `glwe`, and add the result to `out`.
|
||||
@@ -917,20 +939,31 @@ pub fn cmux_assign_mem_optimized_requirement<Scalar>(
|
||||
/// "Multiplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}"
|
||||
/// );
|
||||
/// ```
|
||||
pub fn programmable_bootstrap_lwe_ciphertext<Scalar, InputCont, OutputCont, AccCont, KeyCont>(
|
||||
pub fn programmable_bootstrap_lwe_ciphertext<
|
||||
InputScalar,
|
||||
OutputScalar,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
AccCont,
|
||||
KeyCont,
|
||||
>(
|
||||
input: &LweCiphertext<InputCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
AccCont: Container<Element = Scalar>,
|
||||
InputScalar: UnsignedTorus + CastInto<usize>,
|
||||
OutputScalar: UnsignedTorus,
|
||||
InputCont: Container<Element = InputScalar>,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
AccCont: Container<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = c64>,
|
||||
{
|
||||
assert_eq!(input.ciphertext_modulus(), output.ciphertext_modulus());
|
||||
assert!(
|
||||
input.ciphertext_modulus().is_power_of_two(),
|
||||
"This operation requires the input to have a power of two modulus."
|
||||
);
|
||||
assert_eq!(
|
||||
output.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus()
|
||||
@@ -942,7 +975,7 @@ pub fn programmable_bootstrap_lwe_ciphertext<Scalar, InputCont, OutputCont, AccC
|
||||
let fft = fft.as_view();
|
||||
|
||||
buffers.resize(
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<Scalar>(
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<OutputScalar>(
|
||||
fourier_bsk.glwe_size(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
@@ -968,7 +1001,8 @@ pub fn programmable_bootstrap_lwe_ciphertext<Scalar, InputCont, OutputCont, AccC
|
||||
/// capacity at least as large as the result of
|
||||
/// [`programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement`].
|
||||
pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized<
|
||||
Scalar,
|
||||
InputScalar,
|
||||
OutputScalar,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
AccCont,
|
||||
@@ -982,20 +1016,13 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized<
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
AccCont: Container<Element = Scalar>,
|
||||
InputScalar: UnsignedTorus + CastInto<usize>,
|
||||
OutputScalar: UnsignedTorus,
|
||||
InputCont: Container<Element = InputScalar>,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
AccCont: Container<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = c64>,
|
||||
{
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
output.ciphertext_modulus(),
|
||||
"Mismatched moduli between input ({:?}) and output ({:?})",
|
||||
input.ciphertext_modulus(),
|
||||
output.ciphertext_modulus()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
accumulator.ciphertext_modulus(),
|
||||
output.ciphertext_modulus(),
|
||||
@@ -1031,10 +1058,10 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized<
|
||||
}
|
||||
|
||||
/// Return the required memory for [`programmable_bootstrap_lwe_ciphertext_mem_optimized`].
|
||||
pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<Scalar>(
|
||||
pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<OutputScalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
@@ -700,3 +700,138 @@ fn lwe_encrypt_pbs_ntt64_decrypt_custom_mod(params: ClassicTestParams<u64>) {
|
||||
create_parametrized_test!(lwe_encrypt_pbs_ntt64_decrypt_custom_mod {
|
||||
TEST_PARAMS_3_BITS_SOLINAS_U64
|
||||
});
|
||||
|
||||
#[test]
|
||||
fn test_lwe_encrypt_pbs_switch_mod_switch_scalar_decrypt_custom_mod() {
|
||||
let params = super::TEST_PARAMS_4_BITS_NATIVE_U64;
|
||||
|
||||
let lwe_dimension = params.lwe_dimension;
|
||||
let lwe_noise_distribution_u64 = params.lwe_noise_distribution;
|
||||
let lwe_noise_distribution_u32 =
|
||||
DynamicDistribution::new_gaussian(lwe_noise_distribution_u64.gaussian_std_dev());
|
||||
let glwe_noise_distribution = params.glwe_noise_distribution;
|
||||
let output_ciphertext_modulus = params.ciphertext_modulus;
|
||||
let message_modulus_log = params.message_modulus_log;
|
||||
let output_encoding_with_padding = get_encoding_with_padding(output_ciphertext_modulus);
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let pbs_base_log = params.pbs_base_log;
|
||||
let pbs_level_count = params.pbs_level;
|
||||
|
||||
let input_ciphertext_modulus = CiphertextModulus::<u32>::new_native();
|
||||
let input_encoding_with_padding = get_encoding_with_padding(input_ciphertext_modulus);
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
let input_msg_modulus = 1u32 << message_modulus_log.0;
|
||||
let output_msg_modulus = 1u64 << message_modulus_log.0;
|
||||
let mut msg = input_msg_modulus;
|
||||
let input_delta = input_encoding_with_padding / input_msg_modulus;
|
||||
let output_delta = output_encoding_with_padding / output_msg_modulus;
|
||||
|
||||
let f = |x| x;
|
||||
|
||||
let accumulator = generate_programmable_bootstrap_glwe_lut(
|
||||
polynomial_size,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
output_msg_modulus.cast_into(),
|
||||
output_ciphertext_modulus,
|
||||
output_delta,
|
||||
f,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&accumulator,
|
||||
output_ciphertext_modulus
|
||||
));
|
||||
|
||||
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key::<u32, _>(
|
||||
lwe_dimension,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let lwe_sk_as_u64 = LweSecretKey::from_container(
|
||||
lwe_sk
|
||||
.as_ref()
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| x as u64)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let bsk = par_allocate_and_generate_new_lwe_bootstrap_key(
|
||||
&lwe_sk_as_u64,
|
||||
&glwe_sk,
|
||||
pbs_base_log,
|
||||
pbs_level_count,
|
||||
glwe_noise_distribution,
|
||||
output_ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&*bsk,
|
||||
output_ciphertext_modulus
|
||||
));
|
||||
|
||||
let mut fbsk = FourierLweBootstrapKey::new(
|
||||
bsk.input_lwe_dimension(),
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
);
|
||||
|
||||
par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fbsk);
|
||||
drop(bsk);
|
||||
|
||||
while msg != 0 {
|
||||
msg -= 1;
|
||||
for _ in 0..NB_TESTS {
|
||||
let plaintext = Plaintext(msg * input_delta);
|
||||
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&lwe_sk,
|
||||
plaintext,
|
||||
lwe_noise_distribution_u32,
|
||||
input_ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&ct,
|
||||
input_ciphertext_modulus
|
||||
));
|
||||
|
||||
let mut output_ct = LweCiphertext::new(
|
||||
0u64,
|
||||
fbsk.output_lwe_dimension().to_lwe_size(),
|
||||
output_ciphertext_modulus,
|
||||
);
|
||||
|
||||
programmable_bootstrap_lwe_ciphertext(&ct, &mut output_ct, &accumulator, &fbsk);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&output_ct,
|
||||
output_ciphertext_modulus
|
||||
));
|
||||
|
||||
let decrypted = decrypt_lwe_ciphertext(&glwe_sk.as_lwe_secret_key(), &output_ct);
|
||||
|
||||
let decoded = round_decode(decrypted.0, output_delta) % output_msg_modulus;
|
||||
|
||||
assert_eq!(msg as u64, decoded);
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(feature = "__coverage")]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,13 +239,16 @@ pub fn bootstrap_scratch<Scalar>(
|
||||
|
||||
impl<'a> FourierLweBootstrapKeyView<'a> {
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
pub fn blind_rotate_assign<Scalar: UnsignedTorus + CastInto<usize>>(
|
||||
pub fn blind_rotate_assign<InputScalar, OutputScalar>(
|
||||
self,
|
||||
mut lut: GlweCiphertextMutView<'_, Scalar>,
|
||||
lwe: &[Scalar],
|
||||
mut lut: GlweCiphertextMutView<'_, OutputScalar>,
|
||||
lwe: &[InputScalar],
|
||||
fft: FftView<'_>,
|
||||
mut stack: PodStack<'_>,
|
||||
) {
|
||||
) where
|
||||
InputScalar: UnsignedTorus + CastInto<usize>,
|
||||
OutputScalar: UnsignedTorus,
|
||||
{
|
||||
let (lwe_body, lwe_mask) = lwe.split_last().unwrap();
|
||||
|
||||
let lut_poly_size = lut.polynomial_size();
|
||||
@@ -273,7 +276,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
|
||||
|
||||
for (lwe_mask_element, bootstrap_key_ggsw) in izip!(lwe_mask.iter(), self.into_ggsw_iter())
|
||||
{
|
||||
if *lwe_mask_element != Scalar::ZERO {
|
||||
if *lwe_mask_element != InputScalar::ZERO {
|
||||
let monomial_degree =
|
||||
MonomialDegree(pbs_modulus_switch(*lwe_mask_element, lut_poly_size));
|
||||
|
||||
@@ -321,20 +324,22 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bootstrap<Scalar>(
|
||||
pub fn bootstrap<InputScalar, OutputScalar>(
|
||||
self,
|
||||
mut lwe_out: LweCiphertextMutView<'_, Scalar>,
|
||||
lwe_in: LweCiphertextView<'_, Scalar>,
|
||||
accumulator: GlweCiphertextView<'_, Scalar>,
|
||||
mut lwe_out: LweCiphertextMutView<'_, OutputScalar>,
|
||||
lwe_in: LweCiphertextView<'_, InputScalar>,
|
||||
accumulator: GlweCiphertextView<'_, OutputScalar>,
|
||||
fft: FftView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
InputScalar: UnsignedTorus + CastInto<usize>,
|
||||
OutputScalar: UnsignedTorus,
|
||||
{
|
||||
debug_assert_eq!(lwe_out.ciphertext_modulus(), lwe_in.ciphertext_modulus());
|
||||
debug_assert_eq!(
|
||||
lwe_in.ciphertext_modulus(),
|
||||
assert!(lwe_in.ciphertext_modulus().is_power_of_two());
|
||||
assert!(lwe_out.ciphertext_modulus().is_power_of_two());
|
||||
assert_eq!(
|
||||
lwe_out.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus()
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user