feat(core): add variable Scalar type to PBS for input and output

This commit is contained in:
Arthur Meyre
2024-06-05 17:24:12 +02:00
parent 179fbfc9bb
commit cf5fd87efb
3 changed files with 215 additions and 48 deletions

View File

@@ -177,18 +177,26 @@ use dyn_stack::{PodStack, SizeOverflow, StackReq};
/// "Multiplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}"
/// );
/// ```
pub fn blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
pub fn blind_rotate_assign<InputScalar, OutputScalar, InputCont, OutputCont, KeyCont>(
input: &LweCiphertext<InputCont>,
lut: &mut GlweCiphertext<OutputCont>,
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
InputScalar: UnsignedTorus + CastInto<usize>,
OutputScalar: UnsignedTorus,
InputCont: Container<Element = InputScalar>,
OutputCont: ContainerMut<Element = OutputScalar>,
KeyCont: Container<Element = c64>,
{
assert_eq!(input.ciphertext_modulus(), lut.ciphertext_modulus());
assert!(
input.ciphertext_modulus().is_power_of_two(),
"This operation requires the input to have a power of two modulus."
);
assert!(
lut.ciphertext_modulus().is_power_of_two(),
"This operation requires the lut to have a power of two modulus."
);
let mut buffers = ComputationBuffers::new();
@@ -196,7 +204,7 @@ pub fn blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
let fft = fft.as_view();
buffers.resize(
blind_rotate_assign_mem_optimized_requirement::<Scalar>(
blind_rotate_assign_mem_optimized_requirement::<OutputScalar>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
@@ -213,7 +221,13 @@ pub fn blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
/// Memory optimized version of [`blind_rotate_assign`], the caller must provide
/// a properly configured [`FftView`] object and a `PodStack` used as a memory buffer having a
/// capacity at least as large as the result of [`blind_rotate_assign_mem_optimized_requirement`].
pub fn blind_rotate_assign_mem_optimized<Scalar, InputCont, OutputCont, KeyCont>(
pub fn blind_rotate_assign_mem_optimized<
InputScalar,
OutputScalar,
InputCont,
OutputCont,
KeyCont,
>(
input: &LweCiphertext<InputCont>,
lut: &mut GlweCiphertext<OutputCont>,
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
@@ -221,12 +235,20 @@ pub fn blind_rotate_assign_mem_optimized<Scalar, InputCont, OutputCont, KeyCont>
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
InputScalar: UnsignedTorus + CastInto<usize>,
OutputScalar: UnsignedTorus,
InputCont: Container<Element = InputScalar>,
OutputCont: ContainerMut<Element = OutputScalar>,
KeyCont: Container<Element = c64>,
{
assert_eq!(input.ciphertext_modulus(), lut.ciphertext_modulus());
assert!(
input.ciphertext_modulus().is_power_of_two(),
"This operation requires the input to have a power of two modulus."
);
assert!(
lut.ciphertext_modulus().is_power_of_two(),
"This operation requires the lut to have a power of two modulus."
);
// Blind rotate assign manages the rounding to go back to the proper torus if the ciphertext
// modulus is not the native one
@@ -236,12 +258,12 @@ pub fn blind_rotate_assign_mem_optimized<Scalar, InputCont, OutputCont, KeyCont>
}
/// Return the required memory for [`blind_rotate_assign_mem_optimized`].
pub fn blind_rotate_assign_mem_optimized_requirement<Scalar>(
pub fn blind_rotate_assign_mem_optimized_requirement<OutputScalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
blind_rotate_assign_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
}
/// Compute the external product of `ggsw` and `glwe`, and add the result to `out`.
@@ -917,20 +939,31 @@ pub fn cmux_assign_mem_optimized_requirement<Scalar>(
/// "Multiplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}"
/// );
/// ```
pub fn programmable_bootstrap_lwe_ciphertext<Scalar, InputCont, OutputCont, AccCont, KeyCont>(
pub fn programmable_bootstrap_lwe_ciphertext<
InputScalar,
OutputScalar,
InputCont,
OutputCont,
AccCont,
KeyCont,
>(
input: &LweCiphertext<InputCont>,
output: &mut LweCiphertext<OutputCont>,
accumulator: &GlweCiphertext<AccCont>,
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
AccCont: Container<Element = Scalar>,
InputScalar: UnsignedTorus + CastInto<usize>,
OutputScalar: UnsignedTorus,
InputCont: Container<Element = InputScalar>,
OutputCont: ContainerMut<Element = OutputScalar>,
AccCont: Container<Element = OutputScalar>,
KeyCont: Container<Element = c64>,
{
assert_eq!(input.ciphertext_modulus(), output.ciphertext_modulus());
assert!(
input.ciphertext_modulus().is_power_of_two(),
"This operation requires the input to have a power of two modulus."
);
assert_eq!(
output.ciphertext_modulus(),
accumulator.ciphertext_modulus()
@@ -942,7 +975,7 @@ pub fn programmable_bootstrap_lwe_ciphertext<Scalar, InputCont, OutputCont, AccC
let fft = fft.as_view();
buffers.resize(
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<Scalar>(
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<OutputScalar>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
@@ -968,7 +1001,8 @@ pub fn programmable_bootstrap_lwe_ciphertext<Scalar, InputCont, OutputCont, AccC
/// capacity at least as large as the result of
/// [`programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement`].
pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized<
Scalar,
InputScalar,
OutputScalar,
InputCont,
OutputCont,
AccCont,
@@ -982,20 +1016,13 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized<
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
AccCont: Container<Element = Scalar>,
InputScalar: UnsignedTorus + CastInto<usize>,
OutputScalar: UnsignedTorus,
InputCont: Container<Element = InputScalar>,
OutputCont: ContainerMut<Element = OutputScalar>,
AccCont: Container<Element = OutputScalar>,
KeyCont: Container<Element = c64>,
{
assert_eq!(
input.ciphertext_modulus(),
output.ciphertext_modulus(),
"Mismatched moduli between input ({:?}) and output ({:?})",
input.ciphertext_modulus(),
output.ciphertext_modulus()
);
assert_eq!(
accumulator.ciphertext_modulus(),
output.ciphertext_modulus(),
@@ -1031,10 +1058,10 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized<
}
/// Return the required memory for [`programmable_bootstrap_lwe_ciphertext_mem_optimized`].
pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<Scalar>(
pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<OutputScalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft)
bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
}

View File

@@ -700,3 +700,138 @@ fn lwe_encrypt_pbs_ntt64_decrypt_custom_mod(params: ClassicTestParams<u64>) {
create_parametrized_test!(lwe_encrypt_pbs_ntt64_decrypt_custom_mod {
TEST_PARAMS_3_BITS_SOLINAS_U64
});
#[test]
fn test_lwe_encrypt_pbs_switch_mod_switch_scalar_decrypt_custom_mod() {
let params = super::TEST_PARAMS_4_BITS_NATIVE_U64;
let lwe_dimension = params.lwe_dimension;
let lwe_noise_distribution_u64 = params.lwe_noise_distribution;
let lwe_noise_distribution_u32 =
DynamicDistribution::new_gaussian(lwe_noise_distribution_u64.gaussian_std_dev());
let glwe_noise_distribution = params.glwe_noise_distribution;
let output_ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let output_encoding_with_padding = get_encoding_with_padding(output_ciphertext_modulus);
let glwe_dimension = params.glwe_dimension;
let polynomial_size = params.polynomial_size;
let pbs_base_log = params.pbs_base_log;
let pbs_level_count = params.pbs_level;
let input_ciphertext_modulus = CiphertextModulus::<u32>::new_native();
let input_encoding_with_padding = get_encoding_with_padding(input_ciphertext_modulus);
let mut rsc = TestResources::new();
let input_msg_modulus = 1u32 << message_modulus_log.0;
let output_msg_modulus = 1u64 << message_modulus_log.0;
let mut msg = input_msg_modulus;
let input_delta = input_encoding_with_padding / input_msg_modulus;
let output_delta = output_encoding_with_padding / output_msg_modulus;
let f = |x| x;
let accumulator = generate_programmable_bootstrap_glwe_lut(
polynomial_size,
glwe_dimension.to_glwe_size(),
output_msg_modulus.cast_into(),
output_ciphertext_modulus,
output_delta,
f,
);
assert!(check_encrypted_content_respects_mod(
&accumulator,
output_ciphertext_modulus
));
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key::<u32, _>(
lwe_dimension,
&mut rsc.secret_random_generator,
);
let lwe_sk_as_u64 = LweSecretKey::from_container(
lwe_sk
.as_ref()
.iter()
.copied()
.map(|x| x as u64)
.collect::<Vec<_>>(),
);
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut rsc.secret_random_generator,
);
let bsk = par_allocate_and_generate_new_lwe_bootstrap_key(
&lwe_sk_as_u64,
&glwe_sk,
pbs_base_log,
pbs_level_count,
glwe_noise_distribution,
output_ciphertext_modulus,
&mut rsc.encryption_random_generator,
);
assert!(check_encrypted_content_respects_mod(
&*bsk,
output_ciphertext_modulus
));
let mut fbsk = FourierLweBootstrapKey::new(
bsk.input_lwe_dimension(),
bsk.glwe_size(),
bsk.polynomial_size(),
bsk.decomposition_base_log(),
bsk.decomposition_level_count(),
);
par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fbsk);
drop(bsk);
while msg != 0 {
msg -= 1;
for _ in 0..NB_TESTS {
let plaintext = Plaintext(msg * input_delta);
let ct = allocate_and_encrypt_new_lwe_ciphertext(
&lwe_sk,
plaintext,
lwe_noise_distribution_u32,
input_ciphertext_modulus,
&mut rsc.encryption_random_generator,
);
assert!(check_encrypted_content_respects_mod(
&ct,
input_ciphertext_modulus
));
let mut output_ct = LweCiphertext::new(
0u64,
fbsk.output_lwe_dimension().to_lwe_size(),
output_ciphertext_modulus,
);
programmable_bootstrap_lwe_ciphertext(&ct, &mut output_ct, &accumulator, &fbsk);
assert!(check_encrypted_content_respects_mod(
&output_ct,
output_ciphertext_modulus
));
let decrypted = decrypt_lwe_ciphertext(&glwe_sk.as_lwe_secret_key(), &output_ct);
let decoded = round_decode(decrypted.0, output_delta) % output_msg_modulus;
assert_eq!(msg as u64, decoded);
}
// In coverage, we break after one while loop iteration, changing message values does not
// yield higher coverage
#[cfg(feature = "__coverage")]
break;
}
}

View File

@@ -239,13 +239,16 @@ pub fn bootstrap_scratch<Scalar>(
impl<'a> FourierLweBootstrapKeyView<'a> {
// CastInto required for PBS modulus switch which returns a usize
pub fn blind_rotate_assign<Scalar: UnsignedTorus + CastInto<usize>>(
pub fn blind_rotate_assign<InputScalar, OutputScalar>(
self,
mut lut: GlweCiphertextMutView<'_, Scalar>,
lwe: &[Scalar],
mut lut: GlweCiphertextMutView<'_, OutputScalar>,
lwe: &[InputScalar],
fft: FftView<'_>,
mut stack: PodStack<'_>,
) {
) where
InputScalar: UnsignedTorus + CastInto<usize>,
OutputScalar: UnsignedTorus,
{
let (lwe_body, lwe_mask) = lwe.split_last().unwrap();
let lut_poly_size = lut.polynomial_size();
@@ -273,7 +276,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
for (lwe_mask_element, bootstrap_key_ggsw) in izip!(lwe_mask.iter(), self.into_ggsw_iter())
{
if *lwe_mask_element != Scalar::ZERO {
if *lwe_mask_element != InputScalar::ZERO {
let monomial_degree =
MonomialDegree(pbs_modulus_switch(*lwe_mask_element, lut_poly_size));
@@ -321,20 +324,22 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
}
}
pub fn bootstrap<Scalar>(
pub fn bootstrap<InputScalar, OutputScalar>(
self,
mut lwe_out: LweCiphertextMutView<'_, Scalar>,
lwe_in: LweCiphertextView<'_, Scalar>,
accumulator: GlweCiphertextView<'_, Scalar>,
mut lwe_out: LweCiphertextMutView<'_, OutputScalar>,
lwe_in: LweCiphertextView<'_, InputScalar>,
accumulator: GlweCiphertextView<'_, OutputScalar>,
fft: FftView<'_>,
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
InputScalar: UnsignedTorus + CastInto<usize>,
OutputScalar: UnsignedTorus,
{
debug_assert_eq!(lwe_out.ciphertext_modulus(), lwe_in.ciphertext_modulus());
debug_assert_eq!(
lwe_in.ciphertext_modulus(),
assert!(lwe_in.ciphertext_modulus().is_power_of_two());
assert!(lwe_out.ciphertext_modulus().is_power_of_two());
assert_eq!(
lwe_out.ciphertext_modulus(),
accumulator.ciphertext_modulus()
);