mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 15:48:20 -05:00
Compare commits
2 Commits
al/div_mul
...
dt/ci/impr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c42dee5e51 | ||
|
|
37d25a659f |
8
Makefile
8
Makefile
@@ -314,13 +314,13 @@ test_core_crypto: install_rs_build_toolchain install_rs_check_toolchain
|
||||
.PHONY: test_core_crypto_cov # Run the tests of the core_crypto module with code coverage
|
||||
test_core_crypto_cov: install_rs_build_toolchain install_rs_check_toolchain install_tarpaulin
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) tarpaulin --profile $(CARGO_PROFILE) \
|
||||
--out xml --output-dir coverage/core_crypto --line --engine llvm --timeout 500 \
|
||||
--out xml --ignore-panics --output-dir coverage/core_crypto --line --engine llvm --timeout 500 \
|
||||
--implicit-test-threads $(COVERAGE_EXCLUDED_FILES) \
|
||||
--features=$(TARGET_ARCH_FEATURE),experimental,internal-keycache,__coverage \
|
||||
-p $(TFHE_SPEC) -- core_crypto::
|
||||
@if [[ "$(AVX512_SUPPORT)" == "ON" ]]; then \
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) tarpaulin --profile $(CARGO_PROFILE) \
|
||||
--out xml --output-dir coverage/core_crypto_avx512 --line --engine llvm --timeout 500 \
|
||||
--out xml --ignore-panics --output-dir coverage/core_crypto_avx512 --line --engine llvm --timeout 500 \
|
||||
--implicit-test-threads $(COVERAGE_EXCLUDED_FILES) \
|
||||
--features=$(TARGET_ARCH_FEATURE),experimental,internal-keycache,__coverage,$(AVX512_FEATURE) \
|
||||
-p $(TFHE_SPEC) -- core_crypto::; \
|
||||
@@ -334,7 +334,7 @@ test_boolean: install_rs_build_toolchain
|
||||
.PHONY: test_boolean_cov # Run the tests of the boolean module with code coverage
|
||||
test_boolean_cov: install_rs_check_toolchain install_tarpaulin
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) tarpaulin --profile $(CARGO_PROFILE) \
|
||||
--out xml --output-dir coverage/boolean --line --engine llvm --timeout 500 \
|
||||
--out xml --ignore-panics --output-dir coverage/boolean --line --engine llvm --timeout 500 \
|
||||
$(COVERAGE_EXCLUDED_FILES) \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,internal-keycache,__coverage \
|
||||
-p $(TFHE_SPEC) -- boolean::
|
||||
@@ -375,7 +375,7 @@ test_shortint: install_rs_build_toolchain
|
||||
.PHONY: test_shortint_cov # Run the tests of the shortint module with code coverage
|
||||
test_shortint_cov: install_rs_check_toolchain install_tarpaulin
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) tarpaulin --profile $(CARGO_PROFILE) \
|
||||
--out xml --output-dir coverage/shortint --line --engine llvm --timeout 500 \
|
||||
--out xml --ignore-panics --output-dir coverage/shortint --line --engine llvm --timeout 500 \
|
||||
$(COVERAGE_EXCLUDED_FILES) \
|
||||
--features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache,__coverage \
|
||||
-p $(TFHE_SPEC) -- shortint::
|
||||
|
||||
@@ -62,6 +62,7 @@ pub fn generate_keys<
|
||||
ClassicBootstrapKeys {
|
||||
small_lwe_sk: input_lwe_secret_key,
|
||||
big_lwe_sk: output_lwe_secret_key,
|
||||
glwe_sk: output_glwe_secret_key,
|
||||
bsk,
|
||||
fbsk,
|
||||
}
|
||||
@@ -324,3 +325,406 @@ fn lwe_encrypt_pbs_f128_decrypt_custom_mod_test_params_4_bits_native_u128() {
|
||||
fn lwe_encrypt_pbs_f128_decrypt_custom_mod_test_params_3_bits_127_u128() {
|
||||
lwe_encrypt_pbs_f128_decrypt_custom_mod(TEST_PARAMS_3_BITS_127_U128);
|
||||
}
|
||||
|
||||
fn blind_rotate<Scalar>(params: ClassicTestParams<Scalar>)
|
||||
where
|
||||
Scalar: UnsignedTorus
|
||||
+ Sync
|
||||
+ Send
|
||||
+ CastFrom<usize>
|
||||
+ CastInto<usize>
|
||||
+ Serialize
|
||||
+ DeserializeOwned,
|
||||
ClassicTestParams<Scalar>: KeyCacheAccess<Keys = ClassicBootstrapKeys<Scalar>>,
|
||||
{
|
||||
let lwe_modular_std_dev = params.lwe_modular_std_dev;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let message_modulus_log = params.message_modulus_log;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
let input_message = msg_modulus.wrapping_sub(Scalar::ONE);
|
||||
|
||||
let f = |x: Scalar| Scalar::TWO * x;
|
||||
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
|
||||
let mut accumulator = generate_accumulator(
|
||||
polynomial_size,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
msg_modulus.cast_into(),
|
||||
ciphertext_modulus,
|
||||
delta,
|
||||
f,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&accumulator,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let mut keys_gen = |params| generate_keys(params, &mut rsc);
|
||||
|
||||
let keys = gen_keys_or_get_from_cache_if_enabled(params, &mut keys_gen);
|
||||
let (input_lwe_secret_key, output_lwe_secret_key, fbsk) =
|
||||
(keys.small_lwe_sk, keys.big_lwe_sk, keys.fbsk);
|
||||
|
||||
// Apply our encoding
|
||||
let plaintext = Plaintext(input_message * delta);
|
||||
|
||||
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&input_lwe_secret_key,
|
||||
plaintext,
|
||||
lwe_modular_std_dev,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
// Allocate the LweCiphertext to store the result of the PBS
|
||||
let mut pbs_multiplication_ct = LweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
blind_rotate_assign(&lwe_ciphertext_in, &mut accumulator, &fbsk);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(
|
||||
&accumulator,
|
||||
&mut pbs_multiplication_ct,
|
||||
MonomialDegree(0),
|
||||
);
|
||||
|
||||
// Decrypt the PBS multiplication result
|
||||
let pbs_multiplication_plaintext =
|
||||
decrypt_lwe_ciphertext(&output_lwe_secret_key, &pbs_multiplication_ct);
|
||||
|
||||
// Create a SignedDecomposer to perform the rounding of the decrypted plaintext
|
||||
// We pass a DecompositionBaseLog of message modulus (n) and a DecompositionLevelCount of 1
|
||||
// indicating we want to round the n+1 MSB, 1 bit of padding plus our n bits of message
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(message_modulus_log.0 + 1),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
|
||||
// Round and remove our encoding
|
||||
let pbs_multiplication_result =
|
||||
signed_decomposer.closest_representable(pbs_multiplication_plaintext.0) / delta;
|
||||
|
||||
assert_eq!(pbs_multiplication_result, f(input_message));
|
||||
}
|
||||
|
||||
create_parametrized_test!(blind_rotate);
|
||||
|
||||
fn add_external_product<Scalar>(params: ClassicTestParams<Scalar>)
|
||||
where
|
||||
Scalar: UnsignedTorus
|
||||
+ Sync
|
||||
+ Send
|
||||
+ CastFrom<usize>
|
||||
+ CastInto<usize>
|
||||
+ Serialize
|
||||
+ DeserializeOwned,
|
||||
ClassicTestParams<Scalar>: KeyCacheAccess<Keys = ClassicBootstrapKeys<Scalar>>,
|
||||
{
|
||||
let glwe_size = GlweSize(2);
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let message_modulus_log = params.message_modulus_log;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let decomp_base_log = params.pbs_base_log;
|
||||
let decomp_level_count = params.pbs_level;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
let input_message = msg_modulus.wrapping_sub(Scalar::ONE);
|
||||
let delta: Scalar = (encoding_with_padding / msg_modulus) * Scalar::TWO;
|
||||
|
||||
let mut keys_gen = |params| generate_keys(params, &mut rsc);
|
||||
|
||||
let keys = gen_keys_or_get_from_cache_if_enabled(params, &mut keys_gen);
|
||||
let glwe_secret_key = keys.glwe_sk;
|
||||
|
||||
let msg_ggsw = Plaintext(input_message * delta);
|
||||
|
||||
// Create a new GgswCiphertext
|
||||
let mut ggsw = GgswCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_constant_ggsw_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut ggsw,
|
||||
msg_ggsw,
|
||||
glwe_modular_std_dev,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
let ct_plaintext = Plaintext(input_message * delta);
|
||||
|
||||
let ct_plaintexts = PlaintextList::new(ct_plaintext.0, PlaintextCount(polynomial_size.0));
|
||||
|
||||
let mut ct = GlweCiphertext::new(Scalar::ZERO, glwe_size, polynomial_size, ciphertext_modulus);
|
||||
|
||||
encrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut ct,
|
||||
&ct_plaintexts,
|
||||
glwe_modular_std_dev,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
let buffer_size_req = add_external_product_assign_mem_optimized_requirement::<Scalar>(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required();
|
||||
|
||||
let buffer_size_req = buffer_size_req.max(
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
buffers.resize(buffer_size_req);
|
||||
|
||||
let mut fourier_ggsw = FourierGgswCiphertext::new(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
);
|
||||
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized(
|
||||
&ggsw,
|
||||
&mut fourier_ggsw,
|
||||
fft,
|
||||
buffers.stack(),
|
||||
);
|
||||
|
||||
let mut ct_out = ct.clone();
|
||||
|
||||
add_external_product_assign(&mut ct_out, &fourier_ggsw, &ct);
|
||||
|
||||
let mut output_plaintext_list =
|
||||
PlaintextList::new(Scalar::ZERO, ct_plaintexts.plaintext_count());
|
||||
|
||||
decrypt_glwe_ciphertext(&glwe_secret_key, &ct_out, &mut output_plaintext_list);
|
||||
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(message_modulus_log.0),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
|
||||
output_plaintext_list
|
||||
.iter_mut()
|
||||
.for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
|
||||
|
||||
// As we cloned the input ciphertext for the output, the external product result is added to the
|
||||
// originally contained value, hence why we expect ct_plaintext + ct_plaintext * msg_ggsw
|
||||
let expected = ct_plaintext.0 + ct_plaintext.0 * msg_ggsw.0;
|
||||
|
||||
assert!(output_plaintext_list.iter().all(|x| *x.0 == expected));
|
||||
}
|
||||
|
||||
// FIXME: test works with native value for ciphertext modulus but fails with custom one
|
||||
create_parametrized_test!(add_external_product {
|
||||
TEST_PARAMS_4_BITS_NATIVE_U64
|
||||
});
|
||||
|
||||
fn cmux<Scalar>(params: ClassicTestParams<Scalar>)
|
||||
where
|
||||
Scalar: UnsignedTorus
|
||||
+ Sync
|
||||
+ Send
|
||||
+ CastFrom<usize>
|
||||
+ CastInto<usize>
|
||||
+ Serialize
|
||||
+ DeserializeOwned,
|
||||
ClassicTestParams<Scalar>: KeyCacheAccess<Keys = ClassicBootstrapKeys<Scalar>>,
|
||||
{
|
||||
let glwe_size = GlweSize(2);
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let message_modulus_log = params.message_modulus_log;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let decomp_base_log = params.pbs_base_log;
|
||||
let decomp_level_count = params.pbs_level;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
let delta: Scalar = (encoding_with_padding / msg_modulus) * Scalar::TWO;
|
||||
|
||||
let mut keys_gen = |params| generate_keys(params, &mut rsc);
|
||||
|
||||
let keys = gen_keys_or_get_from_cache_if_enabled(params, &mut keys_gen);
|
||||
let glwe_secret_key = keys.glwe_sk;
|
||||
|
||||
// Create the plaintext
|
||||
let msg_ggsw_0 = Plaintext(Scalar::ZERO);
|
||||
|
||||
// Create a new GgswCiphertext
|
||||
let mut ggsw_0 = GgswCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_constant_ggsw_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut ggsw_0,
|
||||
msg_ggsw_0,
|
||||
glwe_modular_std_dev,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
// Create the plaintext
|
||||
let msg_ggsw_1 = Plaintext(Scalar::ONE);
|
||||
|
||||
// Create a new GgswCiphertext
|
||||
let mut ggsw_1 = GgswCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_constant_ggsw_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut ggsw_1,
|
||||
msg_ggsw_1,
|
||||
glwe_modular_std_dev,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
let ct0_plaintext = Plaintext(Scalar::ONE * delta);
|
||||
let ct1_plaintext = Plaintext(msg_modulus.wrapping_sub(Scalar::ONE) * delta);
|
||||
|
||||
let ct0_plaintexts = PlaintextList::new(ct0_plaintext.0, PlaintextCount(polynomial_size.0));
|
||||
let ct1_plaintexts = PlaintextList::new(ct1_plaintext.0, PlaintextCount(polynomial_size.0));
|
||||
|
||||
let mut ct0 = GlweCiphertext::new(Scalar::ZERO, glwe_size, polynomial_size, ciphertext_modulus);
|
||||
let mut ct1 = GlweCiphertext::new(Scalar::ZERO, glwe_size, polynomial_size, ciphertext_modulus);
|
||||
|
||||
encrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut ct0,
|
||||
&ct0_plaintexts,
|
||||
glwe_modular_std_dev,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
encrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut ct1,
|
||||
&ct1_plaintexts,
|
||||
glwe_modular_std_dev,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
let fft = Fft::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
let buffer_size_req =
|
||||
cmux_assign_mem_optimized_requirement::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required();
|
||||
|
||||
let buffer_size_req = buffer_size_req.max(
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
buffers.resize(buffer_size_req);
|
||||
|
||||
let mut fourier_ggsw_0 = FourierGgswCiphertext::new(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
);
|
||||
let mut fourier_ggsw_1 = FourierGgswCiphertext::new(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
);
|
||||
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized(
|
||||
&ggsw_0,
|
||||
&mut fourier_ggsw_0,
|
||||
fft,
|
||||
buffers.stack(),
|
||||
);
|
||||
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized(
|
||||
&ggsw_1,
|
||||
&mut fourier_ggsw_1,
|
||||
fft,
|
||||
buffers.stack(),
|
||||
);
|
||||
|
||||
let mut ct0_clone = ct0.clone();
|
||||
let mut ct1_clone = ct1.clone();
|
||||
|
||||
cmux_assign(&mut ct0_clone, &mut ct1_clone, &fourier_ggsw_0);
|
||||
|
||||
let mut output_plaintext_list_0 =
|
||||
PlaintextList::new(Scalar::ZERO, ct0_plaintexts.plaintext_count());
|
||||
|
||||
decrypt_glwe_ciphertext(&glwe_secret_key, &ct0_clone, &mut output_plaintext_list_0);
|
||||
|
||||
let signed_decomposer =
|
||||
SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
|
||||
output_plaintext_list_0
|
||||
.iter_mut()
|
||||
.for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
|
||||
|
||||
assert!(output_plaintext_list_0
|
||||
.iter()
|
||||
.all(|x| *x.0 == ct0_plaintext.0));
|
||||
|
||||
cmux_assign_mem_optimized(&mut ct0, &mut ct1, &fourier_ggsw_1, fft, buffers.stack());
|
||||
|
||||
let mut output_plaintext_list_1 =
|
||||
PlaintextList::new(Scalar::ZERO, ct1_plaintexts.plaintext_count());
|
||||
|
||||
decrypt_glwe_ciphertext(&glwe_secret_key, &ct0, &mut output_plaintext_list_1);
|
||||
|
||||
output_plaintext_list_1
|
||||
.iter_mut()
|
||||
.for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
|
||||
|
||||
assert!(output_plaintext_list_1
|
||||
.iter()
|
||||
.all(|x| *x.0 == ct1_plaintext.0));
|
||||
}
|
||||
|
||||
create_parametrized_test!(cmux);
|
||||
|
||||
@@ -9,8 +9,8 @@ use serde::{Deserialize, Serialize};
|
||||
pub struct ClassicBootstrapKeys<Scalar: UnsignedInteger> {
|
||||
pub small_lwe_sk: LweSecretKey<Vec<Scalar>>,
|
||||
pub big_lwe_sk: LweSecretKey<Vec<Scalar>>,
|
||||
pub glwe_sk: GlweSecretKey<Vec<Scalar>>,
|
||||
pub bsk: LweBootstrapKeyOwned<Scalar>,
|
||||
|
||||
pub fbsk: FourierLweBootstrapKeyOwned,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user