Compare commits

..

2 Commits

Author SHA1 Message Date
Arthur Meyre
943be06cda feat(core): bring extended PBS to core_crypto::experimental module
- for now no dedicated types have been created for the extended
bootstrap, meaning an extended BSK is merely seen as a BSK/ExtensionFactor
couple
2026-04-27 18:25:19 +02:00
Arthur Meyre
c4eb3013a0 chore: fix spacing in Makefile 2026-04-27 16:07:16 +02:00
12 changed files with 1306 additions and 18 deletions

View File

@@ -106,8 +106,8 @@ install_rs_check_toolchain:
# We don't check that it exists, because we always want the latest
# and the command below will install/update
install_rs_latest_nightly_toolchain:
rustup toolchain install --profile default nightly || \
( echo "Unable to install nightly toolchain, check your rustup installation. \
rustup toolchain install --profile default nightly || \
( echo "Unable to install nightly toolchain, check your rustup installation. \
Rustup can be downloaded at https://rustup.rs/" && exit 1 )
.PHONY: install_rs_msrv_toolchain # Install the msrv toolchain
@@ -160,7 +160,7 @@ install_node:
$(SHELL) -i -c 'nvm install $(NODE_VERSION)' || \
( echo "Unable to install node, unknown error." && exit 1 )
.PHONY: node_version # Return Node version that will be installed
.PHONY: node_version # Return Node version that will be installed
node_version:
@echo "$(NODE_VERSION)"
@@ -195,7 +195,7 @@ install_typos_checker:
install_zizmor:
@./scripts/install_zizmor.sh --zizmor-version $(ZIZMOR_VERSION)
.PHONY: zizmor_version # Return zizmor version that will be installed
.PHONY: zizmor_version # Return zizmor version that will be installed
zizmor_version:
@echo "$(ZIZMOR_VERSION)"
@@ -334,7 +334,7 @@ fmt_c_tests:
fmt_toml: install_taplo
taplo fmt
.PHONY: check_fmt_c_tests # Check C tests format
.PHONY: check_fmt_c_tests # Check C tests format
check_fmt_c_tests:
find tfhe/c_api_tests/ -regex '.*\.\(cpp\|hpp\|cu\|c\|h\)' -exec clang-format --dry-run --Werror -style=file {} \;
@@ -1137,7 +1137,7 @@ test_high_level_api:
test_high_level_api_gpu_fast: install_cargo_nextest # Run all the GPU tests for high_level_api except test_uniformity for oprf which is too long
RUSTFLAGS="$(RUSTFLAGS)" cargo nextest run --cargo-profile $(CARGO_PROFILE) \
--test-threads=4 --features=integer,internal-keycache,gpu,zk-pok -p tfhe \
-E "test(/high_level_api::.*gpu.*/) and not test(/uniformity/)"
-E "test(/high_level_api::.*gpu.*/) and not test(/uniformity/)"
test_high_level_api_gpu: install_cargo_nextest # Run all the GPU tests for high_level_api
@@ -1153,12 +1153,12 @@ test_list_gpu: install_cargo_nextest
.PHONY: build_one_hl_api_test_gpu
build_one_hl_api_test_gpu:
RUSTFLAGS="$(RUSTFLAGS)" cargo test --no-run \
--features=integer,gpu-debug -vv -p tfhe -- "$${TEST}" --test-threads=1 --nocapture
--features=integer,gpu-debug -vv -p tfhe -- "$${TEST}" --test-threads=1 --nocapture
.PHONY: build_one_hl_api_test_fake_multi_gpu
build_one_hl_api_test_fake_multi_gpu:
RUSTFLAGS="$(RUSTFLAGS)" cargo test --no-run \
--features=integer,gpu-debug-fake-multi-gpu -vv -p tfhe -- "$${TEST}" --test-threads=1 --nocapture
--features=integer,gpu-debug-fake-multi-gpu -vv -p tfhe -- "$${TEST}" --test-threads=1 --nocapture
test_high_level_api_hpu: install_cargo_nextest
ifeq ($(HPU_CONFIG), v80)
@@ -1430,7 +1430,7 @@ run_web_js_api_parallel: build_web_js_api_parallel setup_venv
python ci/webdriver.py \
--browser-path $(browser_path) \
--driver-path $(driver_path) \
--browser-kind $(browser_kind) \
--browser-kind $(browser_kind) \
--server-cmd $(server_cmd) \
--server-workdir "$(WEB_SERVER_DIR)" \
--id-pattern $(filter) \
@@ -1583,8 +1583,12 @@ clippy_bench: install_rs_check_toolchain
--features=boolean,shortint,integer,internal-keycache,pbs-stats,zk-pok \
-p tfhe-benchmark -- --no-deps -D warnings
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
--features=shortint,internal-keycache \
--features=shortint,internal-keycache \
-p tfhe-benchmark -- --no-deps -D warnings
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
--features=experimental \
-p tfhe-benchmark -- --no-deps -D warnings
.PHONY: clippy_bench_gpu # Run clippy lints on tfhe-benchmark
clippy_bench_gpu: install_rs_check_toolchain
@@ -1975,14 +1979,14 @@ bench_hlapi_dex: install_rs_check_toolchain
.PHONY: bench_hlapi_dex_gpu # Run benchmarks for DEX operations on GPU
bench_hlapi_dex_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench hlapi-dex \
--features=integer,gpu,internal-keycache,pbs-stats -p tfhe-benchmark --profile release_lto_off --
.PHONY: bench_hlapi_dex_gpu_classical # Run benchmarks for DEX operations on GPU with classical parameters
bench_hlapi_dex_gpu_classical: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench hlapi-dex \
--features=integer,gpu,internal-keycache,pbs-stats -p tfhe-benchmark --profile release_lto_off --
@@ -2101,7 +2105,7 @@ bench_summary_gpu: install_rs_check_toolchain
--features=integer,gpu,internal-keycache -p tfhe-benchmark --profile release_lto_off -- '::transfer::overflow'
# DEX
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench hlapi-dex \
--features=integer,gpu,internal-keycache,pbs-stats -p tfhe-benchmark --profile release_lto_off -- '::no_cmux::'
@@ -2254,7 +2258,7 @@ pcc_batch_1:
pcc_batch_2:
$(call run_recipe_with_details,clippy)
$(call run_recipe_with_details,clippy_all_targets)
$(call run_recipe_with_details,check_fmt_js) # This needs to stay there, CI pipeline rely on this recipe to conditionally install Node
$(call run_recipe_with_details,check_fmt_js) # This needs to stay there, CI pipeline rely on this recipe to conditionally install Node
$(call run_recipe_with_details,clippy_test_vectors)
$(call run_recipe_with_details,check_test_vectors)
$(call run_recipe_with_details,clippy_wasm_par_mq)
@@ -2278,7 +2282,7 @@ pcc_batch_5:
$(call run_recipe_with_details,clippy_backward_compat_data)
$(call run_recipe_with_details,check_backward_compat_locks_did_not_change)
.PHONY: pcc_batch_6 # duration: 6'32''
.PHONY: pcc_batch_6 # duration: 6'32''
pcc_batch_6:
$(call run_recipe_with_details,clippy_boolean)
$(call run_recipe_with_details,clippy_c_api)

View File

@@ -55,6 +55,9 @@ avx512 = ["tfhe/avx512"]
pbs-stats = ["tfhe/pbs-stats"]
zk-pok = ["tfhe/zk-pok", "dep:tfhe-zk-pok"]
# experimental section
experimental = ["tfhe/experimental"]
[[bench]]
name = "boolean"
path = "benches/boolean/bench.rs"
@@ -205,6 +208,12 @@ path = "benches/core_crypto/pbs128_bench.rs"
harness = false
required-features = ["shortint", "internal-keycache"]
[[bench]]
name = "core_crypto-experimental_extended_pbs"
path = "benches/core_crypto/experimental_extended_pbs.rs"
harness = false
required-features = ["experimental"]
[[bench]]
name = "zk-msm"
path = "benches/zk/msm.rs"

View File

@@ -0,0 +1,350 @@
use benchmark::utilities::{get_bench_type, BenchmarkType};
use criterion::{black_box, Criterion, Throughput};
use rayon::prelude::*;
use tfhe::core_crypto::experimental::prelude::*;
use tfhe::core_crypto::prelude::*;
/// Parameter set consumed by the keyswitch + extended PBS benchmark in this file.
///
/// `max_norm2` and `log2_p_fail` are carried for reference only: the benchmark ignores them when
/// destructuring the parameters, hence the `dead_code` allowances.
pub struct ExtendedPBSBenchParameters {
// Dimension of the small LWE secret key (keyswitch destination, bootstrap key input).
lwe_dimension: LweDimension,
// Dimension of the GLWE secret key used to build the bootstrap key.
glwe_dimension: GlweDimension,
// Polynomial size of the bootstrap key; the LUT uses this size times `extension_factor`.
polynomial_size: PolynomialSize,
// Multiplier applied to `polynomial_size` to get the extended LUT size; also the number of
// worker buffers allocated by the benchmark.
extension_factor: LweBootstrapExtensionFactor,
// Noise used for the keyswitch key and for the input ciphertext encryptions.
lwe_noise_distribution: DynamicDistribution<u64>,
// Noise used for the bootstrap key.
glwe_noise_distribution: DynamicDistribution<u64>,
// Decomposition parameters of the bootstrap key.
pbs_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
// Decomposition parameters of the keyswitch key.
ks_base_log: DecompositionBaseLog,
ks_level: DecompositionLevelCount,
// The product of both moduli is the plaintext modulus used to derive `delta` and the LUT.
message_modulus: CleartextModulus<MessageSpace>,
carry_modulus: CleartextModulus<CarrySpace>,
#[allow(dead_code)]
max_norm2: MaxNorm2,
#[allow(dead_code)]
log2_p_fail: f64,
ciphertext_modulus: CiphertextModulus<u64>,
// The benchmark asserts that this is `EncryptionKeyChoice::Big`.
encryption_key_choice: EncryptionKeyChoice,
}
// Parameters for 2 bits of message and 2 bits of carry with an extension factor of 16.
// p-fail = 2^-128.147, algorithmic cost ~ 67456140, 2-norm = 5, extension factor = 16
// NOTE(review): `log2_p_fail` below is -128.0 while the line above quotes 2^-128.147; the field
// is not read by the benchmark, confirm which value is intended.
const BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128: ExtendedPBSBenchParameters =
ExtendedPBSBenchParameters {
lwe_dimension: LweDimension(884),
glwe_dimension: GlweDimension(4),
polynomial_size: PolynomialSize(512),
extension_factor: LweBootstrapExtensionFactor::new(16),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
1.4999005934396873e-06,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
2.845267479601915e-15,
)),
pbs_base_log: DecompositionBaseLog(23),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(5),
ks_level: DecompositionLevelCount(3),
message_modulus: CleartextModulus::new(4),
carry_modulus: CleartextModulus::new(4),
max_norm2: MaxNorm2(5f64),
log2_p_fail: -128.0,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
};
// Parameter sets iterated by `ks_extended_pbs`, as (name used in the benchmark id, parameters).
const KS_EPBS_BENCH_PARAMS: [(&str, &ExtendedPBSBenchParameters); 1] = [(
"BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128",
&BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128,
)];
/// Returns half of the ciphertext modulus.
///
/// For the native modulus this is `1 << (Scalar::BITS - 1)`; otherwise it is the custom modulus
/// divided by two and cast back to `Scalar`.
fn get_encoding_with_padding<Scalar: UnsignedInteger>(
    ciphertext_modulus: CiphertextModulus<Scalar>,
) -> Scalar {
    // Custom modulus: halve the explicit modulus value.
    if !ciphertext_modulus.is_native_modulus() {
        let half_modulus = ciphertext_modulus.get_custom_modulus() / 2;
        return Scalar::cast_from(half_modulus);
    }

    // Native modulus: only the top bit of the scalar is set.
    Scalar::ONE << (Scalar::BITS - 1)
}
/// Benchmarks a keyswitch followed by the parallelized extended programmable bootstrap for every
/// entry of `KS_EPBS_BENCH_PARAMS`.
///
/// For each parameter set the keys (small LWE key, GLWE key, keyswitch key, Fourier bootstrap
/// key), the identity LUT of size `polynomial_size * extension_factor` and the computation
/// buffers are created once. Then, depending on `get_bench_type()`:
/// - latency: a single keyswitch + extended PBS is timed per iteration;
/// - throughput: a batch whose size is chosen by `find_optimal_batch` is processed with a rayon
///   parallel iterator, each element owning its own ciphertexts, LUT and buffers.
///
/// The bootstrap consumes the *keyswitched* ciphertext (`ks_buffer`), i.e. the ciphertext under
/// the small LWE key that matches the input dimension of the bootstrap key.
///
/// # Panics
///
/// Panics if a parameter set does not use `EncryptionKeyChoice::Big`.
fn ks_extended_pbs(criterion: &mut Criterion) {
    let bench_name = "core_crypto::ks_extended_pbs";
    let mut bench_group = criterion.benchmark_group(bench_name);

    // Create the PRNG
    let mut seeder = new_seeder();
    let seeder = seeder.as_mut();
    let mut encryption_generator =
        EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
    let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());

    for (name, params) in KS_EPBS_BENCH_PARAMS {
        let ExtendedPBSBenchParameters {
            lwe_dimension,
            glwe_dimension,
            polynomial_size,
            extension_factor,
            lwe_noise_distribution,
            glwe_noise_distribution,
            pbs_base_log,
            pbs_level,
            ks_base_log,
            ks_level,
            message_modulus,
            carry_modulus,
            max_norm2: _,
            log2_p_fail: _,
            ciphertext_modulus,
            encryption_key_choice,
        } = *params;

        let plaintext_modulus = message_modulus.0 * carry_modulus.0;
        let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
        let delta = encoding_with_padding / plaintext_modulus;

        // Inputs are encrypted under the big key, keyswitched, then bootstrapped.
        assert!(matches!(encryption_key_choice, EncryptionKeyChoice::Big));

        let lwe_sk =
            allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);

        let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
            glwe_dimension,
            polynomial_size,
            &mut secret_generator,
        );
        let big_lwe_sk = glwe_sk.as_lwe_secret_key();

        let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
            &big_lwe_sk,
            &lwe_sk,
            ks_base_log,
            ks_level,
            lwe_noise_distribution,
            ciphertext_modulus,
            &mut encryption_generator,
        );

        let bsk = allocate_and_generate_new_lwe_bootstrap_key(
            &lwe_sk,
            &glwe_sk,
            pbs_base_log,
            pbs_level,
            glwe_noise_distribution,
            ciphertext_modulus,
            &mut encryption_generator,
        );

        let mut fourier_bsk = FourierLweBootstrapKey::new(
            bsk.input_lwe_dimension(),
            bsk.glwe_size(),
            bsk.polynomial_size(),
            bsk.decomposition_base_log(),
            bsk.decomposition_level_count(),
        );
        par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fourier_bsk);

        // Identity LUT over the extended polynomial size.
        let f = |x: u64| x;
        let accumulator = generate_programmable_bootstrap_glwe_lut(
            PolynomialSize(polynomial_size.0 * extension_factor.get()),
            glwe_dimension.to_glwe_size(),
            plaintext_modulus.cast_into(),
            ciphertext_modulus,
            delta,
            f,
        );

        let fft = Fft::new(fourier_bsk.polynomial_size());
        let fft = fft.as_view();

        let mut buffers = ComputationBuffers::new();

        // TODO: have req for main thread and for workers ?
        use extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement as rq;
        let requirement = rq::<u64>(
            glwe_dimension.to_glwe_size(),
            polynomial_size,
            extension_factor,
            fft,
        )
        .unaligned_bytes_required();
        buffers.resize(requirement);

        // One scratch buffer per worker thread of the extended PBS.
        let mut thread_buffers = Vec::with_capacity(extension_factor.get());
        for _ in 0..extension_factor.get() {
            let mut buffer = ComputationBuffers::new();
            buffer.resize(requirement);
            thread_buffers.push(buffer);
        }
        let mut thread_stacks: Vec<_> = thread_buffers.iter_mut().map(|x| x.stack()).collect();

        let bench_id;

        match get_bench_type() {
            BenchmarkType::Latency => {
                let ct = allocate_and_encrypt_new_lwe_ciphertext(
                    &big_lwe_sk,
                    Plaintext(0),
                    lwe_noise_distribution,
                    ciphertext_modulus,
                    &mut encryption_generator,
                );
                let mut ks_buffer =
                    LweCiphertext::new(0, lwe_sk.lwe_dimension().to_lwe_size(), ciphertext_modulus);
                let mut output_ct = ct.clone();
                output_ct.as_mut().fill(0);

                bench_id = format!("{bench_name}::{name}");
                bench_group.bench_function(&bench_id, |b| {
                    b.iter(|| {
                        keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut ks_buffer);
                        // The bootstrap input is the keyswitch output (small key), not the
                        // original big-key ciphertext.
                        extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
                            &fourier_bsk,
                            &mut output_ct,
                            &ks_buffer,
                            &accumulator,
                            extension_factor,
                            fft,
                            buffers.stack(),
                            &mut thread_stacks,
                        );
                        black_box(&mut output_ct);
                    })
                });
            }
            BenchmarkType::Throughput => {
                bench_id = format!("{bench_name}::throughput::{name}");

                // Builds `batch_size` independent work items, each with its own buffers.
                let mut setup = |batch_size: usize| {
                    let inputs = (0..batch_size)
                        .map(|_| {
                            let ct = allocate_and_encrypt_new_lwe_ciphertext(
                                &big_lwe_sk,
                                Plaintext(0),
                                lwe_noise_distribution,
                                ciphertext_modulus,
                                &mut encryption_generator,
                            );
                            let ks_buffer = LweCiphertext::new(
                                0,
                                lwe_sk.lwe_dimension().to_lwe_size(),
                                ciphertext_modulus,
                            );
                            let mut output_ct = ct.clone();
                            output_ct.as_mut().fill(0);

                            let accumulator = generate_programmable_bootstrap_glwe_lut(
                                PolynomialSize(polynomial_size.0 * extension_factor.get()),
                                glwe_dimension.to_glwe_size(),
                                plaintext_modulus.cast_into(),
                                ciphertext_modulus,
                                delta,
                                f,
                            );

                            let fft = Fft::new(fourier_bsk.polynomial_size());
                            let fft = fft.as_view();

                            let mut main_thread_buffer = ComputationBuffers::new();
                            let requirement = rq::<u64>(
                                glwe_dimension.to_glwe_size(),
                                polynomial_size,
                                extension_factor,
                                fft,
                            )
                            .unaligned_bytes_required();
                            main_thread_buffer.resize(requirement);

                            let mut thread_buffers = Vec::with_capacity(extension_factor.get());
                            for _ in 0..extension_factor.get() {
                                let mut buffer = ComputationBuffers::new();
                                buffer.resize(requirement);
                                thread_buffers.push(buffer);
                            }

                            (
                                ct,
                                ks_buffer,
                                output_ct,
                                accumulator,
                                main_thread_buffer,
                                thread_buffers,
                            )
                        })
                        .collect::<Vec<_>>();
                    inputs
                };

                type Res = Vec<(
                    LweCiphertext<Vec<u64>>,  // Input
                    LweCiphertext<Vec<u64>>,  // KS result
                    LweCiphertext<Vec<u64>>,  // PBS result
                    GlweCiphertext<Vec<u64>>, // Accumulator
                    ComputationBuffers,       // Main thread buffer
                    Vec<ComputationBuffers>,  // Worker thread buffer
                )>;

                let run = |inputs: &mut Res| {
                    inputs.par_iter_mut().for_each(
                        |(
                            ct,
                            ks_buffer,
                            output_ct,
                            accumulator,
                            main_thread_buffer,
                            thread_buffers,
                        )| {
                            let mut thread_stacks: Vec<_> =
                                thread_buffers.iter_mut().map(|x| x.stack()).collect();
                            keyswitch_lwe_ciphertext(&ksk_big_to_small, ct, ks_buffer);
                            // The bootstrap input is the keyswitch output (small key), not the
                            // original big-key ciphertext.
                            extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
                                &fourier_bsk,
                                output_ct,
                                ks_buffer,
                                accumulator,
                                extension_factor,
                                fft,
                                main_thread_buffer.stack(),
                                &mut thread_stacks,
                            );
                            black_box(output_ct);
                        },
                    )
                };

                let elements = {
                    use benchmark::find_optimal_batch::find_optimal_batch;
                    find_optimal_batch(|inputs, _batch_size| run(inputs), &mut setup) as u64
                };
                bench_group.throughput(Throughput::Elements(elements));
                bench_group.bench_function(&bench_id, |b| {
                    b.iter_batched(
                        || setup(elements as usize),
                        |mut inputs| run(&mut inputs),
                        criterion::BatchSize::SmallInput,
                    )
                });
            }
        };
    }
}
/// Runs the keyswitch + extended PBS benchmarks with a dedicated Criterion configuration:
/// 15 samples and a 60 second measurement time, further adjusted by command line arguments.
pub fn extended_pbs_group() {
    let measurement_time = std::time::Duration::from_secs(60);

    let mut criterion: Criterion<_> = Criterion::default()
        .sample_size(15)
        .measurement_time(measurement_time)
        .configure_from_args();

    ks_extended_pbs(&mut criterion);
}
/// Runs every CPU benchmark group defined in this file (currently only `extended_pbs_group`).
fn go_through_cpu_bench_groups() {
extended_pbs_group();
}
/// Benchmark entry point: runs the CPU benchmark groups, then asks a default Criterion instance
/// configured from the command line to emit the final summary.
fn main() {
    go_through_cpu_bench_groups();

    let summary_criterion = Criterion::default().configure_from_args();
    summary_criterion.final_summary();
}

View File

@@ -1,3 +1,4 @@
pub use benchmark_spec::{get_bench_type, BenchmarkType};
use benchmark_spec::{Backend, BenchmarkSpec, OperandType};
use criterion::Criterion;
use serde::Serialize;

View File

@@ -4,11 +4,11 @@ use crate::core_crypto::prelude::*;
/// Represent any kind of LWE ciphertext after a modulus switch operation.
///
/// This may be used as an input to the blind rotation.
pub trait ModulusSwitchedLweCiphertext<Scalar> {
pub trait ModulusSwitchedLweCiphertext<SwitchedScalar> {
fn log_modulus(&self) -> CiphertextModulusLog;
fn lwe_dimension(&self) -> LweDimension;
fn body(&self) -> Scalar;
fn mask(&self) -> impl ExactSizeIterator<Item = Scalar> + '_;
fn body(&self) -> SwitchedScalar;
fn mask(&self) -> impl ExactSizeIterator<Item = SwitchedScalar> + '_;
}
pub fn lwe_ciphertext_modulus_switch<Scalar, SwitchedScalar, Cont>(

View File

@@ -11,6 +11,7 @@ use tfhe_versionable::Versionize;
pub use super::ciphertext_modulus::CiphertextModulus;
use super::traits::CastInto;
use crate::core_crypto::backward_compatibility::commons::parameters::*;
pub use crate::core_crypto::commons::dispersion::StandardDev;
/// The number plaintexts in a plaintext list.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
@@ -479,6 +480,28 @@ impl NormalizedHammingWeightBound {
}
}
/// Newtype wrapper around an `f64`.
// NOTE(review): presumably the maximum 2-norm a parameter set was derived for; confirm.
#[derive(Clone, Copy, Debug)]
pub struct MaxNorm2(pub f64);
/// Marker trait for the zero-sized types used to tag a [`CleartextModulus`].
pub trait CleartextMarker: Copy {}
/// Marker type tagging a [`CleartextModulus`] as a message modulus.
#[derive(Clone, Copy, Debug)]
pub struct MessageSpace;
/// Marker type tagging a [`CleartextModulus`] as a carry modulus.
#[derive(Clone, Copy, Debug)]
pub struct CarrySpace;
impl CleartextMarker for MessageSpace {}
impl CleartextMarker for CarrySpace {}
/// A `u64` modulus tagged at the type level with a [`CleartextMarker`], so that moduli carrying
/// different markers are distinct types. The value is public as field `0`.
#[derive(Clone, Copy, Debug)]
pub struct CleartextModulus<T: CleartextMarker>(pub u64, core::marker::PhantomData<T>);
impl<T: CleartextMarker> CleartextModulus<T> {
/// Builds a modulus holding `value`; usable in `const` contexts.
pub const fn new(value: u64) -> Self {
Self(value, core::marker::PhantomData)
}
}
#[cfg(test)]
mod test {
use super::*;

View File

@@ -0,0 +1,460 @@
//! Experimental module containing implementations of extended bootstrapping algorithms.
//!
//! See [this paper](https://eprint.iacr.org/2025/2214.pdf).
use crate::core_crypto::algorithms::glwe_linear_algebra::glwe_ciphertext_sub_assign;
use crate::core_crypto::algorithms::glwe_sample_extraction::extract_lwe_sample_from_glwe_ciphertext;
use crate::core_crypto::algorithms::modulus_switch::{
lwe_ciphertext_modulus_switch, ModulusSwitchedLweCiphertext,
};
use crate::core_crypto::algorithms::polynomial_algorithms::{
polynomial_wrapping_monic_monomial_div, polynomial_wrapping_monic_monomial_mul,
};
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger};
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweSize, MonomialDegree, PolynomialSize,
};
use crate::core_crypto::commons::traits::{
Container, ContainerMut, ContiguousEntityContainer, ContiguousEntityContainerMut,
};
use crate::core_crypto::entities::glwe_ciphertext::GlweCiphertext;
use crate::core_crypto::entities::lwe_ciphertext::LweCiphertext;
use crate::core_crypto::experimental::commons::parameters::LweBootstrapExtensionFactor;
use crate::core_crypto::fft_impl::fft64::c64;
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::FourierLweBootstrapKey;
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
add_external_product_assign, add_external_product_assign_scratch,
};
use crate::core_crypto::fft_impl::fft64::math::fft::FftView;
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{PodStack, StackReq};
use itertools::izip;
use std::cell::UnsafeCell;
/// View over a mutable slice that hands out per-index references through a shared reference.
///
/// The elements are reinterpreted as [`UnsafeCell`]s, which lets [`UnsafeSlice::write`] return a
/// `&mut T` from `&self`. No synchronization is performed: soundness relies on callers upholding
/// the safety contracts of [`UnsafeSlice::read`] and [`UnsafeSlice::write`].
#[derive(Copy, Clone)]
pub struct UnsafeSlice<'a, T> {
    slice: &'a [UnsafeCell<T>],
}

unsafe impl<T: Send + Sync> Send for UnsafeSlice<'_, T> {}
unsafe impl<T: Send + Sync> Sync for UnsafeSlice<'_, T> {}

impl<'a, T> UnsafeSlice<'a, T> {
    /// Wraps `slice`, keeping it mutably borrowed for the lifetime `'a` of the view.
    pub fn new(slice: &'a mut [T]) -> Self {
        let cells = std::ptr::from_mut::<[T]>(slice) as *const [UnsafeCell<T>];
        Self {
            slice: unsafe { &*cells },
        }
    }

    /// Returns a shared reference to the element at `idx`; panics if `idx` is out of bounds.
    ///
    /// # Safety
    ///
    /// The caller must make sure that no concurrent read and write access occur for a given index
    pub unsafe fn read(&self, idx: usize) -> &T {
        &*self.slice[idx].get()
    }

    /// Returns an exclusive reference to the element at `idx`; panics if `idx` is out of bounds.
    ///
    /// # Safety
    ///
    /// The caller must make sure that no concurrent read and write access occur for a given index
    #[allow(clippy::mut_from_ref)] // that's the point of the UnsafeSlice in that case
    pub unsafe fn write(&self, idx: usize) -> &mut T {
        &mut *self.slice[idx].get()
    }
}
/// De-interleaves the coefficients of the extended `lut` into the small LUTs of `split_luts`.
///
/// The coefficient at flat index `idx` of the `lut` container is written to the small LUT
/// `idx % extension_factor`, at flat index `idx / extension_factor`.
///
/// Nothing is validated here: all GlweCiphertexts in `split_luts` must have the same GlweSize
/// and PolynomialSize, and the PolynomialSize of `lut` must be equal to the small LUTs
/// PolynomialSize times `extension_factor`. Out-of-range indices panic.
pub fn unchecked_split_extended_lut_into_small_luts<Scalar, ExtendedLutCont, SplitLutCont>(
    lut: &GlweCiphertext<ExtendedLutCont>,
    split_luts: &mut [GlweCiphertext<SplitLutCont>],
    extension_factor: LweBootstrapExtensionFactor,
) where
    Scalar: UnsignedInteger,
    ExtendedLutCont: Container<Element = Scalar>,
    SplitLutCont: ContainerMut<Element = Scalar>,
{
    let factor = extension_factor.get();

    lut.as_ref().iter().enumerate().for_each(|(idx, &coeff)| {
        // Which small LUT receives the coefficient, and where inside it.
        let small_lut_idx = idx % factor;
        let coeff_idx = idx / factor;
        split_luts[small_lut_idx].as_mut()[coeff_idx] = coeff;
    });
}
/// Converts a monomial degree expressed for the extended LUT into the monomial degree to apply
/// to one small (split) LUT.
///
/// The input `ai` is modulus switched under the extended polynomial size N' = 2^nu * N.
///
/// With 2^nu split small LUTs of size N, to compute the full rotation of the extended LUT of
/// size N' by X^ai one needs to:
///
/// Move the LUT at old_lut_idx to new_lut_idx:
/// new_lut_idx = (ai + old_lut_idx) % 2^nu
///
/// In the LUT at new_lut_idx apply a monomial rotation whose exponent is:
/// exponent = (2^nu + (ai % 2N') - 1 - new_lut_idx)/2^nu
///
/// The result of this function is the above exponent, with the division performed as a right
/// shift by log2 of the extension factor.
pub(crate) fn small_lut_monomial_degree_from_extended_lut_monomial_degree(
    extended_lut_monomial_degree: MonomialDegree,
    extension_factor: LweBootstrapExtensionFactor,
    // The index of the small lut being rotated
    small_lut_idx: usize,
) -> MonomialDegree {
    let factor = extension_factor.get();
    let numerator = factor + extended_lut_monomial_degree.0 - 1 - small_lut_idx;
    let log2_factor = factor.ilog2();

    MonomialDegree(numerator >> log2_factor)
}
/// Parallelized extended programmable bootstrap of `lwe_in` into `lwe_out`.
///
/// `accumulator` is the extended LUT: its polynomial size must be the polynomial size of `bsk`
/// times `extension_factor`. The input is modulus switched with respect to that extended
/// polynomial size, then handed to the implementation together with the scratch memory:
/// `stack` for the calling thread and one entry of `thread_buffers` per worker thread.
///
/// NOTE(review): no check that the dimension of `lwe_in` matches `bsk.input_lwe_dimension()` is
/// visible in this function; callers are expected to pass the ciphertext encrypted under the
/// bootstrap key input key.
///
/// # Panics
///
/// Panics if `lwe_out` and `accumulator` have different ciphertext moduli, or if the polynomial
/// sizes are inconsistent with `extension_factor`.
#[allow(clippy::too_many_arguments)]
pub fn extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized<
    Scalar,
    KeyCont,
    OutputCont,
    InputCont,
    AccCont,
>(
    bsk: &FourierLweBootstrapKey<KeyCont>,
    lwe_out: &mut LweCiphertext<OutputCont>,
    lwe_in: &LweCiphertext<InputCont>,
    accumulator: &GlweCiphertext<AccCont>,
    extension_factor: LweBootstrapExtensionFactor,
    fft: FftView<'_>,
    stack: &mut PodStack,
    thread_buffers: &mut [&mut PodStack],
) where
    // CastInto required for PBS modulus switch which returns a usize
    Scalar: UnsignedTorus + CastInto<usize>,
    KeyCont: Container<Element = c64> + Sync,
    OutputCont: ContainerMut<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    AccCont: Container<Element = Scalar>,
{
    assert_eq!(
        lwe_out.ciphertext_modulus(),
        accumulator.ciphertext_modulus()
    );

    let extended_polynomial_size = accumulator.polynomial_size();
    assert_eq!(
        bsk.polynomial_size().0 * extension_factor.get(),
        extended_polynomial_size.0
    );

    // For the extended bootstrap the mod switch is done using the extended polynomial size
    let input_modulus_log = extended_polynomial_size.to_blind_rotation_input_modulus_log();
    let msed = lwe_ciphertext_modulus_switch(lwe_in.as_view(), input_modulus_log);

    extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_impl(
        bsk,
        extension_factor,
        accumulator,
        &msed,
        lwe_out.as_mut_view(),
        fft,
        stack,
        thread_buffers,
    );
}
/// Core of the parallelized extended programmable bootstrap, working on an already modulus
/// switched input.
///
/// Steps, in order:
/// 1. copy `input_lut` into scratch memory taken from `stack`, applying
///    `polynomial_wrapping_monic_monomial_div` with the switched body as monomial degree;
/// 2. carve two sets of `extension_factor` small GLWE buffers out of `stack` and split the copied
///    LUT into the first set;
/// 3. spawn one scoped thread per entry of `thread_stacks`; every thread walks the switched mask
///    elements zipped with the GGSW ciphertexts of `bsk`, ping-ponging between the two buffer
///    sets and waiting on a barrier after each element;
/// 4. sample extract coefficient 0 of small LUT 0 from the buffer set written last into
///    `lwe_out`.
///
/// NOTE(review): the mask iterator is zipped with the GGSW iterator, so a length mismatch
/// between the switched ciphertext and the bootstrap key is not detected here.
///
/// # Panics
///
/// Panics if `thread_stacks.len()` differs from `extension_factor`, if the ciphertext modulus of
/// `input_lut` is not compatible with the native modulus, if the LUT polynomial size is not the
/// `bsk` polynomial size times `extension_factor`, if `msed_lwe` was not switched to the modulus
/// matching the LUT polynomial size, or if a worker thread panics.
#[allow(clippy::too_many_arguments)]
fn extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_impl<
Scalar,
KeyCont,
LutCont,
MsedLwe,
>(
bsk: &FourierLweBootstrapKey<KeyCont>,
extension_factor: LweBootstrapExtensionFactor,
input_lut: &GlweCiphertext<LutCont>,
msed_lwe: &MsedLwe,
mut lwe_out: LweCiphertext<&mut [Scalar]>,
fft: FftView<'_>,
stack: &mut PodStack,
thread_stacks: &mut [&mut PodStack],
) where
Scalar: UnsignedTorus + CastInto<usize>,
KeyCont: Container<Element = c64> + Sync,
LutCont: Container<Element = Scalar>,
MsedLwe: ModulusSwitchedLweCiphertext<usize> + Sync,
{
// One worker stack per small LUT; the barrier below is sized with the same count.
assert_eq!(thread_stacks.len(), extension_factor.get());
let lut_poly_size = input_lut.polynomial_size();
let ciphertext_modulus = input_lut.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
assert_eq!(
bsk.polynomial_size().0 * extension_factor.get(),
lut_poly_size.0
);
assert_eq!(
msed_lwe.log_modulus(),
lut_poly_size.to_blind_rotation_input_modulus_log()
);
// Initial rotation of the extended LUT by the switched body, done on a scratch copy so that
// `input_lut` is left untouched.
let monomial_degree = MonomialDegree(msed_lwe.body());
let (lut_data, stack) = stack.make_aligned_raw(input_lut.as_ref().len(), CACHELINE_ALIGN);
let mut lut = GlweCiphertext::from_container(
&mut *lut_data,
input_lut.polynomial_size(),
input_lut.ciphertext_modulus(),
);
lut.as_mut_polynomial_list()
.iter_mut()
.zip(input_lut.as_polynomial_list().iter())
.for_each(|(mut dst_poly, src_poly)| {
polynomial_wrapping_monic_monomial_div(&mut dst_poly, &src_poly, monomial_degree)
});
// Remove mutability to make things more readable and use 0/1 notation to refer to computation
// buffers
let ct0 = lut;
let mut split_ct0 = Vec::with_capacity(extension_factor.get());
let mut split_ct1 = Vec::with_capacity(extension_factor.get());
// First set of small GLWE work buffers, carved one after the other out of the main stack.
let substack0 = {
let mut current_stack = stack;
for _ in 0..extension_factor.get() {
let (glwe_cont, substack) = current_stack.make_aligned_raw::<Scalar>(
bsk.glwe_size().0 * bsk.polynomial_size().0,
CACHELINE_ALIGN,
);
split_ct0.push(GlweCiphertext::from_container(
glwe_cont,
bsk.polynomial_size(),
ciphertext_modulus,
));
current_stack = substack;
}
current_stack
};
// Second set of small GLWE work buffers, same layout as the first one.
let _substack1 = {
let mut current_stack = substack0;
for _ in 0..extension_factor.get() {
let (glwe_cont, substack) = current_stack.make_aligned_raw::<Scalar>(
bsk.glwe_size().0 * bsk.polynomial_size().0,
CACHELINE_ALIGN,
);
split_ct1.push(GlweCiphertext::from_container(
glwe_cont,
bsk.polynomial_size(),
ciphertext_modulus,
));
current_stack = substack;
}
current_stack
};
// Only the first buffer set is initialized from the rotated LUT; the second one is fully
// overwritten by each thread before being read (see the copy_from_slice below).
unchecked_split_extended_lut_into_small_luts(&ct0, &mut split_ct0, extension_factor);
// Shared views letting every thread read any small LUT while writing only its own slot.
let thread_split_ct0 = UnsafeSlice::new(&mut split_ct0);
let thread_split_ct1 = UnsafeSlice::new(&mut split_ct1);
use std::sync::Barrier;
let barrier = Barrier::new(extension_factor.get());
std::thread::scope(|s| {
let thread_processing = |id: usize, stack: &mut PodStack| {
// ===== Setup thread local resources =====
let ct_dst_idx = id;
// Used as a bit mask to reduce indices modulo the extension factor.
let extension_factor_rem_mask = extension_factor.get() - 1;
let (diff_dyn_array, stack) = stack.make_aligned_raw::<Scalar>(
bsk.glwe_size().0 * bsk.polynomial_size().0,
CACHELINE_ALIGN,
);
let mut diff_buffer = GlweCiphertext::from_container(
diff_dyn_array,
bsk.polynomial_size(),
ciphertext_modulus,
);
// ===== Perform the subpart of the rotation the thread is responsible for =====
for (mask_idx, (monomial_degree, ggsw)) in msed_lwe
.mask()
.map(MonomialDegree)
.zip(bsk.as_view().into_ggsw_iter())
.enumerate()
{
// Update the LUT the current thread looks at simulating the rotation in the
// extended LUT
//
// One thread is responsible of a destination LUT index at all times (alternating
// between two work buffers), we simulate the rotation by seeing which "source" LUT
// would end up in the index the thread is responsible for, this avoids memory
// copies
//
// The index mapping combined with a well chosen monomial multiplication simulates
// the rotation in the extended LUT from rotations of the small split LUTs
let ct_src_idx =
(ct_dst_idx.wrapping_sub(monomial_degree.0)) & extension_factor_rem_mask;
// The well chosen monomial degree allowing to complete the simulated rotation in
// the extended LUT, see function comment for the formula
let small_monomial_degree =
small_lut_monomial_degree_from_extended_lut_monomial_degree(
monomial_degree,
extension_factor,
ct_dst_idx,
);
let rotated_buffer = {
// Even mask indices read from buffer set 0 and write to set 1, odd mask indices
// do the opposite. Within one iteration every thread only writes the slot
// `ct_dst_idx` of the destination set and only reads from the other set.
let (src_to_rotate, dst_rotated, src_unrotated) = if (mask_idx % 2) == 0 {
unsafe {
(
thread_split_ct0.read(ct_src_idx),
thread_split_ct1.write(ct_dst_idx),
thread_split_ct0.read(ct_dst_idx),
)
}
} else {
unsafe {
(
thread_split_ct1.read(ct_src_idx),
thread_split_ct0.write(ct_dst_idx),
thread_split_ct1.read(ct_dst_idx),
)
}
};
// Prepare the destination for the ext prod by copying the unrotated
// accumulator there
dst_rotated.as_mut().copy_from_slice(src_unrotated.as_ref());
for (mut diff_poly, src_to_rotate_poly) in izip!(
diff_buffer.as_mut_polynomial_list().iter_mut(),
src_to_rotate.as_polynomial_list().iter(),
) {
// Rotate the lut that ends up in our slot and add to the
// destination
// This is computing Rot(ACCj)
polynomial_wrapping_monic_monomial_mul(
&mut diff_poly,
&src_to_rotate_poly,
small_monomial_degree,
);
}
// This is computing Rot(ACCj) - ACCj
glwe_ciphertext_sub_assign(&mut diff_buffer, src_unrotated);
dst_rotated
};
// ACCj ← BSKi x (Rot(ACCj) - ACCj) + ACCj
add_external_product_assign(
rotated_buffer.as_mut_view(),
ggsw,
diff_buffer.as_view(),
fft,
stack,
);
// Every thread must be done writing its destination slot before any thread starts
// reading that buffer set in the next iteration.
let _ = barrier.wait();
}
};
#[allow(clippy::needless_collect)]
let threads: Vec<_> = thread_stacks
.iter_mut()
.enumerate()
.map(|(id, stack)| s.spawn(move || thread_processing(id, stack)))
.collect();
for t in threads {
t.join().unwrap();
}
});
// Odd mask indices write to buffer set 0, so after an even number of iterations the last
// write landed in set 0, otherwise in set 1.
let lwe_dimension = bsk.input_lwe_dimension().0;
let buffer_to_use = if lwe_dimension.is_multiple_of(2) {
split_ct0
} else {
split_ct1
};
// Only small LUT 0 is needed for the sample extraction below.
let mut lut_0 = buffer_to_use.into_iter().next().unwrap();
if !ciphertext_modulus.is_native_modulus() {
// When we convert back from the fourier domain, integer values will contain up to 53
// MSBs with information. In our representation of power of 2 moduli < native modulus we
// fill the MSBs and leave the LSBs empty, this usage of the signed decomposer allows to
// round while keeping the data in the MSBs
let signed_decomposer = SignedDecomposer::new(
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
DecompositionLevelCount(1),
);
lut_0
.as_mut()
.iter_mut()
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
}
// NOTE(review): the data is copied into an owned Vec before the sample extraction; this heap
// allocation looks avoidable, confirm whether `lut_0` can be used directly.
let split_accumulator = GlweCiphertext::from_container(
lut_0.as_ref().to_vec(),
lut_0.polynomial_size(),
lut_0.ciphertext_modulus(),
);
extract_lwe_sample_from_glwe_ciphertext(&split_accumulator, &mut lwe_out, MonomialDegree(0));
}
/// Return the required memory for
/// [`extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized`].
///
/// The requirement is the combination of, with k + 1 = `glwe_size`, N = `small_polynomial_size`
/// and 2^nu = `extension_factor`:
/// - three buffers of (k + 1) * N * 2^nu scalars: the local accumulator and both sets of split
///   ciphertexts;
/// - one buffer of (k + 1) * N scalars for the difference buffer;
/// - the scratch space of the external product.
pub fn extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement<
    OutputScalar,
>(
    glwe_size: GlweSize,
    small_polynomial_size: PolynomialSize,
    extension_factor: LweBootstrapExtensionFactor,
    fft: FftView<'_>,
) -> StackReq {
    // (k + 1) * N
    let small_glwe_len = glwe_size.0 * small_polynomial_size.0;
    // (k + 1) * N * 2^nu
    let extended_glwe_len = small_glwe_len * extension_factor.get();

    let extended_glwe_req =
        || StackReq::new_aligned::<OutputScalar>(extended_glwe_len, CACHELINE_ALIGN);

    StackReq::all_of(&[
        // local accumulator allocation
        extended_glwe_req(),
        // split ct0 allocation
        extended_glwe_req(),
        // split ct1 allocation
        extended_glwe_req(),
        // diff_buffer allocation
        StackReq::new_aligned::<OutputScalar>(small_glwe_len, CACHELINE_ALIGN),
        // external product
        add_external_product_assign_scratch::<OutputScalar>(glwe_size, small_polynomial_size, fft),
    ])
}

View File

@@ -1,5 +1,6 @@
pub mod glwe_fast_keyswitch;
pub mod glwe_partial_sample_extraction;
pub mod lwe_extended_programmable_bootstrapping;
pub mod lwe_shrinking_keyswitch;
pub mod lwe_shrinking_keyswitch_key_generation;
pub mod partial_glwe_secret_key_generation;
@@ -10,6 +11,7 @@ pub mod shared_lwe_secret_key_generation;
pub use glwe_fast_keyswitch::*;
pub use glwe_partial_sample_extraction::*;
pub use lwe_extended_programmable_bootstrapping::*;
pub use lwe_shrinking_keyswitch::*;
pub use lwe_shrinking_keyswitch_key_generation::*;
pub use partial_glwe_secret_key_generation::*;

View File

@@ -0,0 +1,342 @@
use super::*;
use crate::core_crypto::algorithms::glwe_secret_key_generation::allocate_and_generate_new_binary_glwe_secret_key;
use crate::core_crypto::algorithms::lwe_bootstrap_key_conversion::par_convert_standard_lwe_bootstrap_key_to_fourier;
use crate::core_crypto::algorithms::lwe_bootstrap_key_generation::par_allocate_and_generate_new_lwe_bootstrap_key;
use crate::core_crypto::algorithms::lwe_encryption::{
allocate_and_encrypt_new_lwe_ciphertext, decrypt_lwe_ciphertext,
};
use crate::core_crypto::algorithms::lwe_programmable_bootstrapping::generate_programmable_bootstrap_glwe_lut;
use crate::core_crypto::algorithms::lwe_secret_key_generation::allocate_and_generate_new_binary_lwe_secret_key;
use crate::core_crypto::algorithms::misc::check_encrypted_content_respects_mod;
use crate::core_crypto::algorithms::polynomial_algorithms;
use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
use crate::core_crypto::commons::parameters::{
CarrySpace, CiphertextModulus, CleartextModulus, DecompositionBaseLog, DecompositionLevelCount,
DynamicDistribution, EncryptionKeyChoice, GlweDimension, GlweSize, LweDimension, MaxNorm2,
MessageSpace, MonomialDegree, PolynomialSize, StandardDev,
};
use crate::core_crypto::commons::traits::{CastInto, ContiguousEntityContainerMut};
use crate::core_crypto::entities::glwe_ciphertext::GlweCiphertext;
use crate::core_crypto::entities::lwe_ciphertext::LweCiphertext;
use crate::core_crypto::entities::plaintext::Plaintext;
use crate::core_crypto::experimental::algorithms::lwe_extended_programmable_bootstrapping::{
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized,
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement,
small_lut_monomial_degree_from_extended_lut_monomial_degree,
unchecked_split_extended_lut_into_small_luts,
};
use crate::core_crypto::experimental::commons::parameters::LweBootstrapExtensionFactor;
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::FourierLweBootstrapKey;
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
// This test checks that the rotation in the extended polynomial ring is correctly performed by
// first moving small polynomials around and then doing a monomial mul in each small polynomial by a
// well chosen X^index
// N' = 2^nu * N
// new_lut_idx = (ai + old_lut_idx) % 2^nu
// small_lut_monomial_index = (2^nu + (ai % 2N') - 1 - new_lut_idx)/2^nu
// looks to work to multiply by X^ai
//
// The reference is obtained by rotating the extended LUT directly and then splitting the result
// into small LUTs; the value under test is obtained by splitting first and then rotating the small
// LUTs with the trick described above. Both are vectors of small LUTs and must be equal.
#[test]
fn test_monic_mul_split_eq() {
    use rand::Rng;

    let mut rng = rand::thread_rng();
    let glwe_size = GlweSize(2);
    let ciphertext_modulus = CiphertextModulus::<u64>::new_native();

    for _ in 0..100 {
        // Random small polynomial size N in [2^5, 2^12] and extension factor 2^nu in {2, 4, 8}
        let small_poly_size = PolynomialSize(1usize << rng.gen_range(5..=12));
        let extension_factor = LweBootstrapExtensionFactor::new(1 << rng.gen_range(1..=3));
        let polynomial_size = PolynomialSize(small_poly_size.0 * extension_factor.get());

        // Fill the extended LUT with random coefficients
        let mut extended_lut =
            GlweCiphertext::new(0u64, glwe_size, polynomial_size, ciphertext_modulus);
        extended_lut
            .as_mut()
            .iter_mut()
            .for_each(|x| *x = rng.gen());
        // Freeze the extended LUT, it is only read from now on
        let extended_lut = extended_lut;

        // Split the unrotated extended LUT once, it is the starting point of every "trick" rotation
        let mut small_luts = vec![
            GlweCiphertext::new(0u64, glwe_size, small_poly_size, ciphertext_modulus);
            extension_factor.get()
        ];
        unchecked_split_extended_lut_into_small_luts(
            &extended_lut,
            &mut small_luts,
            extension_factor,
        );
        let ref_small_luts = small_luts;

        for _ in 0..1000 {
            // Modulo 2*N' to mimic the modulus switch for the extended N' = 2^nu * N polynomial
            // size
            let monomial_degree = MonomialDegree(rng.gen::<usize>() % (polynomial_size.0 * 2));

            // Compute the reference for the rotation results as split luts: rotate in the extended
            // ring first, then split. The block must evaluate to the split small LUTs, as this is
            // what gets compared to the result of the trick below.
            let ref_small_rotated_luts = {
                let mut ref_rotated_lut = extended_lut.clone();
                for mut polynomial_to_rotate in ref_rotated_lut.as_mut_polynomial_list().iter_mut()
                {
                    polynomial_algorithms::polynomial_wrapping_monic_monomial_mul_assign(
                        &mut polynomial_to_rotate,
                        monomial_degree,
                    );
                }

                let mut ref_small_rotated_luts = vec![
                    GlweCiphertext::new(
                        0u64,
                        glwe_size,
                        small_poly_size,
                        ciphertext_modulus
                    );
                    extension_factor.get()
                ];
                unchecked_split_extended_lut_into_small_luts(
                    &ref_rotated_lut,
                    &mut ref_small_rotated_luts,
                    extension_factor,
                );

                ref_small_rotated_luts
            };

            // Compute the rotation results with the trick
            let small_luts_rotated_with_trick = {
                // Copy the unmodified reference extended LUT as small luts
                let mut small_luts_rotated_with_trick = ref_small_luts.clone();

                // Rotate the lookup tables by the right amount to start simulating the extended LUT
                // rotation
                small_luts_rotated_with_trick
                    .rotate_right(monomial_degree.0 % extension_factor.get());

                // Complete the rotation with the monomial degree "trick" for small LUTs, using a
                // well chosen monomial degree for each small LUT rotation
                for (lut_idx, small_lut) in small_luts_rotated_with_trick.iter_mut().enumerate() {
                    let small_monomial_degree =
                        small_lut_monomial_degree_from_extended_lut_monomial_degree(
                            monomial_degree,
                            extension_factor,
                            lut_idx,
                        );

                    for mut polynomial_to_rotate in small_lut.as_mut_polynomial_list().iter_mut() {
                        polynomial_algorithms::polynomial_wrapping_monic_monomial_mul_assign(
                            &mut polynomial_to_rotate,
                            small_monomial_degree,
                        );
                    }
                }

                small_luts_rotated_with_trick
            };

            // Verify our formulas work the way we expect by comparing to the reference not using
            // the trick
            assert_eq!(ref_small_rotated_luts, small_luts_rotated_with_trick);
        }
    }
}
/// Parameter set consumed by the extended PBS tests of this module.
///
/// Some fields are not read by the test visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug)]
pub struct ExtendedPBSTestParameters {
// Dimension of the input LWE secret key
lwe_dimension: LweDimension,
// Dimension of the output GLWE secret key
glwe_dimension: GlweDimension,
// Base (small) polynomial size N; the extended LUT uses polynomial_size * extension_factor
polynomial_size: PolynomialSize,
// Factor by which the LUT polynomial size is multiplied in the extended PBS
extension_factor: LweBootstrapExtensionFactor,
// Noise distribution used when encrypting the input LWE ciphertexts
lwe_noise_distribution: DynamicDistribution<u64>,
// Noise distribution used when generating the bootstrap key
glwe_noise_distribution: DynamicDistribution<u64>,
// Decomposition parameters of the bootstrap key
pbs_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
// Keyswitch decomposition parameters
ks_base_log: DecompositionBaseLog,
ks_level: DecompositionLevelCount,
// The plaintext modulus used by the test is message_modulus * carry_modulus
message_modulus: CleartextModulus<MessageSpace>,
carry_modulus: CleartextModulus<CarrySpace>,
max_norm2: MaxNorm2,
// presumably the log2 of the failure probability the parameters target — TODO confirm
log2_p_fail: f64,
ciphertext_modulus: CiphertextModulus<u64>,
encryption_key_choice: EncryptionKeyChoice,
}
// Extended PBS test parameters with message modulus 4, carry modulus 4 and an extension factor of
// 16 applied to a base polynomial size of 512.
// p-fail = 2^-128.147, algorithmic cost ~ 67456140, 2-norm = 5, extension factor = 16
// NOTE(review): the line above states p-fail = 2^-128.147 while `log2_p_fail` below is -128.0 —
// confirm which value is intended.
pub const TEST_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128: ExtendedPBSTestParameters =
ExtendedPBSTestParameters {
lwe_dimension: LweDimension(884),
glwe_dimension: GlweDimension(4),
polynomial_size: PolynomialSize(512),
extension_factor: LweBootstrapExtensionFactor::new(16),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
1.4999005934396873e-06,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
2.845267479601915e-15,
)),
pbs_base_log: DecompositionBaseLog(23),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(5),
ks_level: DecompositionLevelCount(3),
message_modulus: CleartextModulus::new(4),
carry_modulus: CleartextModulus::new(4),
max_norm2: MaxNorm2(5f64),
log2_p_fail: -128.0,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
};
fn lwe_encrypt_extended_pbs_decrypt(params: ExtendedPBSTestParameters) {
let lwe_dimension = params.lwe_dimension;
let lwe_noise_distribution = params.lwe_noise_distribution;
let message_modulus = params.message_modulus.0;
let carry_modulus = params.carry_modulus.0;
let plaintext_modulus = message_modulus * carry_modulus;
let glwe_dimension = params.glwe_dimension;
let extension_factor = params.extension_factor;
let base_polynomial_size = params.polynomial_size;
let extended_polynomial_size = PolynomialSize(base_polynomial_size.0 * extension_factor.get());
let glwe_noise_distribution = params.glwe_noise_distribution;
let pbs_base_log = params.pbs_base_log;
let pbs_level_count = params.pbs_level;
let ciphertext_modulus = params.ciphertext_modulus;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
let mut rsc = TestResources::new();
let f = |x: u64| x;
let delta: u64 = encoding_with_padding / plaintext_modulus;
let mut msg = plaintext_modulus;
let accumulator = generate_programmable_bootstrap_glwe_lut(
extended_polynomial_size,
glwe_dimension.to_glwe_size(),
plaintext_modulus.cast_into(),
ciphertext_modulus,
delta,
f,
);
assert!(check_encrypted_content_respects_mod(
&accumulator,
ciphertext_modulus
));
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
lwe_dimension,
&mut rsc.secret_random_generator,
);
let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
base_polynomial_size,
&mut rsc.secret_random_generator,
);
let output_lwe_secret_key = output_glwe_secret_key.as_lwe_secret_key();
let bsk = par_allocate_and_generate_new_lwe_bootstrap_key(
&input_lwe_secret_key,
&output_glwe_secret_key,
pbs_base_log,
pbs_level_count,
glwe_noise_distribution,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);
let mut fbsk = FourierLweBootstrapKey::new(
bsk.input_lwe_dimension(),
bsk.glwe_size(),
bsk.polynomial_size(),
bsk.decomposition_base_log(),
bsk.decomposition_level_count(),
);
par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fbsk);
let fft = Fft::new(base_polynomial_size);
let fft = fft.as_view();
let mut buffers = ComputationBuffers::new();
// TODO: have req for main thread and for workers ?
use extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement as rq;
let requirement = rq::<u64>(
glwe_dimension.to_glwe_size(),
base_polynomial_size,
extension_factor,
fft,
)
.unaligned_bytes_required();
buffers.resize(requirement);
let mut thread_buffers = Vec::with_capacity(extension_factor.get());
for _ in 0..extension_factor.get() {
let mut buffer = ComputationBuffers::new();
buffer.resize(requirement);
thread_buffers.push(buffer);
}
let mut thread_stacks: Vec<_> = thread_buffers.iter_mut().map(|x| x.stack()).collect();
while msg != 0 {
msg = msg.wrapping_sub(1);
for _ in 0..10 {
let plaintext = Plaintext(msg * delta);
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
&input_lwe_secret_key,
plaintext,
lwe_noise_distribution,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);
assert!(check_encrypted_content_respects_mod(
&lwe_ciphertext_in,
ciphertext_modulus
));
let mut out_pbs_ct = LweCiphertext::new(
0,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
ciphertext_modulus,
);
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
&fbsk,
&mut out_pbs_ct,
&lwe_ciphertext_in,
&accumulator,
extension_factor,
fft,
buffers.stack(),
&mut thread_stacks,
);
assert!(check_encrypted_content_respects_mod(
&out_pbs_ct,
ciphertext_modulus
));
let decrypted = decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);
let decoded = round_decode(decrypted.0, delta) % plaintext_modulus;
assert_eq!(decoded, f(msg));
}
// In coverage, we break after one while loop iteration, changing message values does not
// yield higher coverage
#[cfg(tarpaulin)]
break;
}
}
// Register `lwe_encrypt_extended_pbs_decrypt` with the parameter sets listed below.
// NOTE(review): `create_parameterized_test!` is defined outside this file; presumably it expands
// to one `#[test]` per listed parameter set — confirm against the macro definition.
create_parameterized_test!(lwe_encrypt_extended_pbs_decrypt {
TEST_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128,
});

View File

@@ -1,5 +1,6 @@
use crate::core_crypto::algorithms::test::*;
use crate::core_crypto::experimental::prelude::*;
mod lwe_extended_programmable_bootstrapping;
mod lwe_fast_keyswitch;
mod lwe_stair_keyswitch;

View File

@@ -43,3 +43,36 @@ impl crate::core_crypto::commons::parameters::LweDimension {
LweSecretKeyUnsharedCoefCount(self.0 - shared_coef_count.0)
}
}
/// Multiplicative factor applied to a LUT polynomial size in the extended PBS setting.
///
/// The extended PBS simulates the rotation of a larger LUT while only manipulating smaller LUTs,
/// which has nice noise properties in some cases. With `N` the small polynomial size and `tau` the
/// extension factor, the extended LUT polynomial size is `N' = tau * N`. `tau` is a power of 2,
/// usually written `tau = 2^nu`, so `N' = 2^nu * N`.
///
/// See [this paper](https://eprint.iacr.org/2025/2214.pdf).
///
/// Currently the extension factor needs to be a power of two to keep a compatible power of two for
/// the extended LUT.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct LweBootstrapExtensionFactor(usize);

impl LweBootstrapExtensionFactor {
    /// Build an extension factor from `value`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not strictly greater than 1 or if it is not a power of two.
    pub const fn new(value: usize) -> Self {
        if value <= 1 {
            panic!("An LweBootstrapExtensionFactor <= 1 makes no sense.");
        }
        if !value.is_power_of_two() {
            panic!("LweBootstrapExtensionFactor needs to be a power of 2");
        }
        Self(value)
    }

    /// Return the wrapped extension factor.
    pub const fn get(&self) -> usize {
        let Self(factor) = *self;
        factor
    }
}

View File

@@ -1477,4 +1477,67 @@ pub(crate) mod test {
}
}
}
/// Uniformity check of the OPRF random bits generation for several shortint parameter sets.
///
/// For each parameter set and for 3 and 4 requested random bits, the blocks returned by
/// `generate_oblivious_pseudo_random_bits` are decrypted and packed, least significant block
/// first, into a single integer which is fed to `test_uniformity` over `2^random_bits_count`
/// possible values.
#[test]
fn oprf_test_uniformity_bits_ci_run_filter() {
    use crate::shortint::gen_keys;
    use crate::shortint::parameters::test_params::{
        TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128,
        TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
    };
    use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

    let sample_count: usize = 100_000;
    let p_value_limit: f64 = 0.000_01;

    let parameter_sets = [
        ShortintParameterSet::from(
            TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
        ),
        ShortintParameterSet::from(PARAM_MESSAGE_2_CARRY_2_KS_PBS),
        ShortintParameterSet::from(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128),
    ];

    for params in parameter_sets {
        let (ck, sk) = gen_keys(params);
        let oprf_ck = OprfPrivateKey::new(&ck);
        let oprf_sk = OprfServerKey::new(&oprf_ck, &ck).unwrap();

        let random_bits_per_block = sk.message_modulus.0.ilog2() as u64;

        for random_bits_count in [3u64, 4] {
            let expected_num_blocks = random_bits_count.div_ceil(random_bits_per_block) as usize;

            test_uniformity(
                sample_count,
                p_value_limit,
                1 << random_bits_count,
                |seed| {
                    let seed_bytes = (seed as u128).to_le_bytes();
                    let blocks = oprf_sk.generate_oblivious_pseudo_random_bits(
                        seed_bytes.as_slice(),
                        random_bits_count,
                        &sk,
                    );

                    // Pack the decrypted blocks, each one shifted by the number of bits carried
                    // by the blocks before it
                    let (combined, _shift) = blocks.iter().enumerate().fold(
                        (0u64, 0u64),
                        |(packed, shift), (i, block)| {
                            let decrypted = ck.decrypt_message_and_carry(block);
                            let block_bits = bits_in_block(
                                i,
                                expected_num_blocks,
                                random_bits_count,
                                random_bits_per_block,
                            );
                            (packed | (decrypted << shift), shift + block_bits)
                        },
                    );

                    combined
                },
            );
        }
    }
}
}