mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
4 Commits
main
...
am/feat/se
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48f464b467 | ||
|
|
1836f1c561 | ||
|
|
943be06cda | ||
|
|
c4eb3013a0 |
34
Makefile
34
Makefile
@@ -106,8 +106,8 @@ install_rs_check_toolchain:
|
||||
# We don't check that it exists, because we always want the latest
|
||||
# and the command below will install/update
|
||||
install_rs_latest_nightly_toolchain:
|
||||
rustup toolchain install --profile default nightly || \
|
||||
( echo "Unable to install nightly toolchain, check your rustup installation. \
|
||||
rustup toolchain install --profile default nightly || \
|
||||
( echo "Unable to install nightly toolchain, check your rustup installation. \
|
||||
Rustup can be downloaded at https://rustup.rs/" && exit 1 )
|
||||
|
||||
.PHONY: install_rs_msrv_toolchain # Install the msrv toolchain
|
||||
@@ -160,7 +160,7 @@ install_node:
|
||||
$(SHELL) -i -c 'nvm install $(NODE_VERSION)' || \
|
||||
( echo "Unable to install node, unknown error." && exit 1 )
|
||||
|
||||
.PHONY: node_version # Return Node version that will be installed
|
||||
.PHONY: node_version # Return Node version that will be installed
|
||||
node_version:
|
||||
@echo "$(NODE_VERSION)"
|
||||
|
||||
@@ -195,7 +195,7 @@ install_typos_checker:
|
||||
install_zizmor:
|
||||
@./scripts/install_zizmor.sh --zizmor-version $(ZIZMOR_VERSION)
|
||||
|
||||
.PHONY: zizmor_version # Return zizmor version that will be installed
|
||||
.PHONY: zizmor_version # Return zizmor version that will be installed
|
||||
zizmor_version:
|
||||
@echo "$(ZIZMOR_VERSION)"
|
||||
|
||||
@@ -334,7 +334,7 @@ fmt_c_tests:
|
||||
fmt_toml: install_taplo
|
||||
taplo fmt
|
||||
|
||||
.PHONY: check_fmt_c_tests # Check C tests format
|
||||
.PHONY: check_fmt_c_tests # Check C tests format
|
||||
check_fmt_c_tests:
|
||||
find tfhe/c_api_tests/ -regex '.*\.\(cpp\|hpp\|cu\|c\|h\)' -exec clang-format --dry-run --Werror -style=file {} \;
|
||||
|
||||
@@ -1137,7 +1137,7 @@ test_high_level_api:
|
||||
test_high_level_api_gpu_fast: install_cargo_nextest # Run all the GPU tests for high_level_api except test_uniformity for oprf which is too long
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo nextest run --cargo-profile $(CARGO_PROFILE) \
|
||||
--test-threads=4 --features=integer,internal-keycache,gpu,zk-pok -p tfhe \
|
||||
-E "test(/high_level_api::.*gpu.*/) and not test(/uniformity/)"
|
||||
-E "test(/high_level_api::.*gpu.*/) and not test(/uniformity/)"
|
||||
|
||||
|
||||
test_high_level_api_gpu: install_cargo_nextest # Run all the GPU tests for high_level_api
|
||||
@@ -1153,12 +1153,12 @@ test_list_gpu: install_cargo_nextest
|
||||
.PHONY: build_one_hl_api_test_gpu
|
||||
build_one_hl_api_test_gpu:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo test --no-run \
|
||||
--features=integer,gpu-debug -vv -p tfhe -- "$${TEST}" --test-threads=1 --nocapture
|
||||
--features=integer,gpu-debug -vv -p tfhe -- "$${TEST}" --test-threads=1 --nocapture
|
||||
|
||||
.PHONY: build_one_hl_api_test_fake_multi_gpu
|
||||
build_one_hl_api_test_fake_multi_gpu:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo test --no-run \
|
||||
--features=integer,gpu-debug-fake-multi-gpu -vv -p tfhe -- "$${TEST}" --test-threads=1 --nocapture
|
||||
--features=integer,gpu-debug-fake-multi-gpu -vv -p tfhe -- "$${TEST}" --test-threads=1 --nocapture
|
||||
|
||||
test_high_level_api_hpu: install_cargo_nextest
|
||||
ifeq ($(HPU_CONFIG), v80)
|
||||
@@ -1430,7 +1430,7 @@ run_web_js_api_parallel: build_web_js_api_parallel setup_venv
|
||||
python ci/webdriver.py \
|
||||
--browser-path $(browser_path) \
|
||||
--driver-path $(driver_path) \
|
||||
--browser-kind $(browser_kind) \
|
||||
--browser-kind $(browser_kind) \
|
||||
--server-cmd $(server_cmd) \
|
||||
--server-workdir "$(WEB_SERVER_DIR)" \
|
||||
--id-pattern $(filter) \
|
||||
@@ -1583,8 +1583,12 @@ clippy_bench: install_rs_check_toolchain
|
||||
--features=boolean,shortint,integer,internal-keycache,pbs-stats,zk-pok \
|
||||
-p tfhe-benchmark -- --no-deps -D warnings
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
|
||||
--features=shortint,internal-keycache \
|
||||
--features=shortint,internal-keycache \
|
||||
-p tfhe-benchmark -- --no-deps -D warnings
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
|
||||
--features=experimental \
|
||||
-p tfhe-benchmark -- --no-deps -D warnings
|
||||
|
||||
|
||||
.PHONY: clippy_bench_gpu # Run clippy lints on tfhe-benchmark
|
||||
clippy_bench_gpu: install_rs_check_toolchain
|
||||
@@ -1975,14 +1979,14 @@ bench_hlapi_dex: install_rs_check_toolchain
|
||||
|
||||
.PHONY: bench_hlapi_dex_gpu # Run benchmarks for DEX operations on GPU
|
||||
bench_hlapi_dex_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench hlapi-dex \
|
||||
--features=integer,gpu,internal-keycache,pbs-stats -p tfhe-benchmark --profile release_lto_off --
|
||||
|
||||
.PHONY: bench_hlapi_dex_gpu_classical # Run benchmarks for DEX operations on GPU with classical parameters
|
||||
bench_hlapi_dex_gpu_classical: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench hlapi-dex \
|
||||
--features=integer,gpu,internal-keycache,pbs-stats -p tfhe-benchmark --profile release_lto_off --
|
||||
@@ -2101,7 +2105,7 @@ bench_summary_gpu: install_rs_check_toolchain
|
||||
--features=integer,gpu,internal-keycache -p tfhe-benchmark --profile release_lto_off -- '::transfer::overflow'
|
||||
|
||||
# DEX
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_PARAM_TYPE=$(BENCH_PARAM_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench hlapi-dex \
|
||||
--features=integer,gpu,internal-keycache,pbs-stats -p tfhe-benchmark --profile release_lto_off -- '::no_cmux::'
|
||||
@@ -2254,7 +2258,7 @@ pcc_batch_1:
|
||||
pcc_batch_2:
|
||||
$(call run_recipe_with_details,clippy)
|
||||
$(call run_recipe_with_details,clippy_all_targets)
|
||||
$(call run_recipe_with_details,check_fmt_js) # This needs to stay there, CI pipeline rely on this recipe to conditionally install Node
|
||||
$(call run_recipe_with_details,check_fmt_js) # This needs to stay there, CI pipeline rely on this recipe to conditionally install Node
|
||||
$(call run_recipe_with_details,clippy_test_vectors)
|
||||
$(call run_recipe_with_details,check_test_vectors)
|
||||
$(call run_recipe_with_details,clippy_wasm_par_mq)
|
||||
@@ -2278,7 +2282,7 @@ pcc_batch_5:
|
||||
$(call run_recipe_with_details,clippy_backward_compat_data)
|
||||
$(call run_recipe_with_details,check_backward_compat_locks_did_not_change)
|
||||
|
||||
.PHONY: pcc_batch_6 # duration: 6'32''
|
||||
.PHONY: pcc_batch_6 # duration: 6'32''
|
||||
pcc_batch_6:
|
||||
$(call run_recipe_with_details,clippy_boolean)
|
||||
$(call run_recipe_with_details,clippy_c_api)
|
||||
|
||||
@@ -55,6 +55,9 @@ avx512 = ["tfhe/avx512"]
|
||||
pbs-stats = ["tfhe/pbs-stats"]
|
||||
zk-pok = ["tfhe/zk-pok", "dep:tfhe-zk-pok"]
|
||||
|
||||
# experimental section
|
||||
experimental = ["tfhe/experimental"]
|
||||
|
||||
[[bench]]
|
||||
name = "boolean"
|
||||
path = "benches/boolean/bench.rs"
|
||||
@@ -205,6 +208,12 @@ path = "benches/core_crypto/pbs128_bench.rs"
|
||||
harness = false
|
||||
required-features = ["shortint", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "core_crypto-experimental_extended_pbs"
|
||||
path = "benches/core_crypto/experimental_extended_pbs.rs"
|
||||
harness = false
|
||||
required-features = ["experimental"]
|
||||
|
||||
[[bench]]
|
||||
name = "zk-msm"
|
||||
path = "benches/zk/msm.rs"
|
||||
|
||||
350
tfhe-benchmark/benches/core_crypto/experimental_extended_pbs.rs
Normal file
350
tfhe-benchmark/benches/core_crypto/experimental_extended_pbs.rs
Normal file
@@ -0,0 +1,350 @@
|
||||
use benchmark::utilities::{get_bench_type, BenchmarkType};
|
||||
use criterion::{black_box, Criterion, Throughput};
|
||||
use rayon::prelude::*;
|
||||
use tfhe::core_crypto::experimental::prelude::*;
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
|
||||
pub struct ExtendedPBSBenchParameters {
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
lwe_noise_distribution: DynamicDistribution<u64>,
|
||||
glwe_noise_distribution: DynamicDistribution<u64>,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
ks_level: DecompositionLevelCount,
|
||||
message_modulus: CleartextModulus<MessageSpace>,
|
||||
carry_modulus: CleartextModulus<CarrySpace>,
|
||||
#[allow(dead_code)]
|
||||
max_norm2: MaxNorm2,
|
||||
#[allow(dead_code)]
|
||||
log2_p_fail: f64,
|
||||
ciphertext_modulus: CiphertextModulus<u64>,
|
||||
encryption_key_choice: EncryptionKeyChoice,
|
||||
}
|
||||
|
||||
// p-fail = 2^-128.147, algorithmic cost ~ 67456140, 2-norm = 5, extension factor = 16,
|
||||
const BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128: ExtendedPBSBenchParameters =
|
||||
ExtendedPBSBenchParameters {
|
||||
lwe_dimension: LweDimension(884),
|
||||
glwe_dimension: GlweDimension(4),
|
||||
polynomial_size: PolynomialSize(512),
|
||||
extension_factor: LweBootstrapExtensionFactor::new(16),
|
||||
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
|
||||
1.4999005934396873e-06,
|
||||
)),
|
||||
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
|
||||
2.845267479601915e-15,
|
||||
)),
|
||||
pbs_base_log: DecompositionBaseLog(23),
|
||||
pbs_level: DecompositionLevelCount(1),
|
||||
ks_base_log: DecompositionBaseLog(5),
|
||||
ks_level: DecompositionLevelCount(3),
|
||||
message_modulus: CleartextModulus::new(4),
|
||||
carry_modulus: CleartextModulus::new(4),
|
||||
max_norm2: MaxNorm2(5f64),
|
||||
log2_p_fail: -128.0,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
};
|
||||
|
||||
const KS_EPBS_BENCH_PARAMS: [(&str, &ExtendedPBSBenchParameters); 1] = [(
|
||||
"BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128",
|
||||
&BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128,
|
||||
)];
|
||||
|
||||
fn get_encoding_with_padding<Scalar: UnsignedInteger>(
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) -> Scalar {
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
Scalar::ONE << (Scalar::BITS - 1)
|
||||
} else {
|
||||
Scalar::cast_from(ciphertext_modulus.get_custom_modulus() / 2)
|
||||
}
|
||||
}
|
||||
|
||||
fn ks_extended_pbs(criterion: &mut Criterion) {
|
||||
let bench_name = "core_crypto::ks_extended_pbs";
|
||||
let mut bench_group = criterion.benchmark_group(bench_name);
|
||||
|
||||
// Create the PRNG
|
||||
let mut seeder = new_seeder();
|
||||
let seeder = seeder.as_mut();
|
||||
let mut encryption_generator =
|
||||
EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
|
||||
let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
|
||||
|
||||
for (name, params) in KS_EPBS_BENCH_PARAMS {
|
||||
let ExtendedPBSBenchParameters {
|
||||
lwe_dimension,
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
extension_factor,
|
||||
lwe_noise_distribution,
|
||||
glwe_noise_distribution,
|
||||
pbs_base_log,
|
||||
pbs_level,
|
||||
ks_base_log,
|
||||
ks_level,
|
||||
message_modulus,
|
||||
carry_modulus,
|
||||
max_norm2: _,
|
||||
log2_p_fail: _,
|
||||
ciphertext_modulus,
|
||||
encryption_key_choice,
|
||||
} = *params;
|
||||
|
||||
let plaintext_modulus = message_modulus.0 * carry_modulus.0;
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
let delta = encoding_with_padding / plaintext_modulus;
|
||||
|
||||
assert!(matches!(encryption_key_choice, EncryptionKeyChoice::Big));
|
||||
|
||||
let lwe_sk =
|
||||
allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
|
||||
|
||||
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
let big_lwe_sk = glwe_sk.as_lwe_secret_key();
|
||||
let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
&big_lwe_sk,
|
||||
&lwe_sk,
|
||||
ks_base_log,
|
||||
ks_level,
|
||||
lwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let bsk = allocate_and_generate_new_lwe_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
pbs_base_log,
|
||||
pbs_level,
|
||||
glwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let mut fourier_bsk = FourierLweBootstrapKey::new(
|
||||
bsk.input_lwe_dimension(),
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
);
|
||||
par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fourier_bsk);
|
||||
|
||||
let f = |x: u64| x;
|
||||
|
||||
let accumulator = generate_programmable_bootstrap_glwe_lut(
|
||||
PolynomialSize(polynomial_size.0 * extension_factor.get()),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
plaintext_modulus.cast_into(),
|
||||
ciphertext_modulus,
|
||||
delta,
|
||||
f,
|
||||
);
|
||||
|
||||
let fft = Fft::new(fourier_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
// TODO: have req for main thread and for workers ?
|
||||
use extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement as rq;
|
||||
|
||||
let requirement = rq::<u64>(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
extension_factor,
|
||||
fft,
|
||||
)
|
||||
.unaligned_bytes_required();
|
||||
|
||||
buffers.resize(requirement);
|
||||
|
||||
let mut thread_buffers = Vec::with_capacity(extension_factor.get());
|
||||
for _ in 0..extension_factor.get() {
|
||||
let mut buffer = ComputationBuffers::new();
|
||||
buffer.resize(requirement);
|
||||
thread_buffers.push(buffer);
|
||||
}
|
||||
|
||||
let mut thread_stacks: Vec<_> = thread_buffers.iter_mut().map(|x| x.stack()).collect();
|
||||
|
||||
let bench_id;
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&big_lwe_sk,
|
||||
Plaintext(0),
|
||||
lwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let mut ks_buffer =
|
||||
LweCiphertext::new(0, lwe_sk.lwe_dimension().to_lwe_size(), ciphertext_modulus);
|
||||
|
||||
let mut output_ct = ct.clone();
|
||||
output_ct.as_mut().fill(0);
|
||||
|
||||
bench_id = format!("{bench_name}::{name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut ks_buffer);
|
||||
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
|
||||
&fourier_bsk,
|
||||
&mut output_ct,
|
||||
&ct,
|
||||
&accumulator,
|
||||
extension_factor,
|
||||
fft,
|
||||
buffers.stack(),
|
||||
&mut thread_stacks,
|
||||
);
|
||||
black_box(&mut output_ct);
|
||||
})
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
bench_id = format!("{bench_name}::throughput::{name}");
|
||||
let mut setup = |batch_size: usize| {
|
||||
let inputs = (0..batch_size)
|
||||
.map(|_| {
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&big_lwe_sk,
|
||||
Plaintext(0),
|
||||
lwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let ks_buffer = LweCiphertext::new(
|
||||
0,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut output_ct = ct.clone();
|
||||
output_ct.as_mut().fill(0);
|
||||
|
||||
let accumulator = generate_programmable_bootstrap_glwe_lut(
|
||||
PolynomialSize(polynomial_size.0 * extension_factor.get()),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
plaintext_modulus.cast_into(),
|
||||
ciphertext_modulus,
|
||||
delta,
|
||||
f,
|
||||
);
|
||||
|
||||
let fft = Fft::new(fourier_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
|
||||
let mut main_thread_buffer = ComputationBuffers::new();
|
||||
|
||||
let requirement = rq::<u64>(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
extension_factor,
|
||||
fft,
|
||||
)
|
||||
.unaligned_bytes_required();
|
||||
|
||||
main_thread_buffer.resize(requirement);
|
||||
|
||||
let mut thread_buffers = Vec::with_capacity(extension_factor.get());
|
||||
for _ in 0..extension_factor.get() {
|
||||
let mut buffer = ComputationBuffers::new();
|
||||
buffer.resize(requirement);
|
||||
thread_buffers.push(buffer);
|
||||
}
|
||||
|
||||
(
|
||||
ct,
|
||||
ks_buffer,
|
||||
output_ct,
|
||||
accumulator,
|
||||
main_thread_buffer,
|
||||
thread_buffers,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
inputs
|
||||
};
|
||||
type Res = Vec<(
|
||||
LweCiphertext<Vec<u64>>, // Input
|
||||
LweCiphertext<Vec<u64>>, // KS result
|
||||
LweCiphertext<Vec<u64>>, // PBS result
|
||||
GlweCiphertext<Vec<u64>>, // Accumulator
|
||||
ComputationBuffers, // Main thread buffer
|
||||
Vec<ComputationBuffers>, // Worker thread buffer
|
||||
)>;
|
||||
let run = |inputs: &mut Res| {
|
||||
inputs.par_iter_mut().for_each(
|
||||
|(
|
||||
ct,
|
||||
ks_buffer,
|
||||
output_ct,
|
||||
accumulator,
|
||||
main_thread_buffer,
|
||||
thread_buffers,
|
||||
)| {
|
||||
let mut thread_stacks: Vec<_> =
|
||||
thread_buffers.iter_mut().map(|x| x.stack()).collect();
|
||||
keyswitch_lwe_ciphertext(&ksk_big_to_small, ct, ks_buffer);
|
||||
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
|
||||
&fourier_bsk,
|
||||
output_ct,
|
||||
ct,
|
||||
accumulator,
|
||||
extension_factor,
|
||||
fft,
|
||||
main_thread_buffer.stack(),
|
||||
&mut thread_stacks,
|
||||
);
|
||||
black_box(output_ct);
|
||||
},
|
||||
)
|
||||
};
|
||||
let elements = {
|
||||
use benchmark::find_optimal_batch::find_optimal_batch;
|
||||
find_optimal_batch(|inputs, _batch_size| run(inputs), &mut setup) as u64
|
||||
};
|
||||
bench_group.throughput(Throughput::Elements(elements));
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter_batched(
|
||||
|| setup(elements as usize),
|
||||
|mut inputs| run(&mut inputs),
|
||||
criterion::BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extended_pbs_group() {
|
||||
let mut criterion: Criterion<_> = (Criterion::default()
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(60)))
|
||||
.configure_from_args();
|
||||
ks_extended_pbs(&mut criterion);
|
||||
}
|
||||
|
||||
fn go_through_cpu_bench_groups() {
|
||||
extended_pbs_group();
|
||||
}
|
||||
|
||||
fn main() {
|
||||
go_through_cpu_bench_groups();
|
||||
|
||||
Criterion::default().configure_from_args().final_summary();
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub use benchmark_spec::{get_bench_type, BenchmarkType};
|
||||
use benchmark_spec::{Backend, BenchmarkSpec, OperandType};
|
||||
use criterion::Criterion;
|
||||
use serde::Serialize;
|
||||
|
||||
@@ -4,11 +4,11 @@ use crate::core_crypto::prelude::*;
|
||||
/// Represent any kind of LWE ciphertext after a modulus switch operation.
|
||||
///
|
||||
/// This may be used as an input to the blind rotatation.
|
||||
pub trait ModulusSwitchedLweCiphertext<Scalar> {
|
||||
pub trait ModulusSwitchedLweCiphertext<SwitchedScalar> {
|
||||
fn log_modulus(&self) -> CiphertextModulusLog;
|
||||
fn lwe_dimension(&self) -> LweDimension;
|
||||
fn body(&self) -> Scalar;
|
||||
fn mask(&self) -> impl ExactSizeIterator<Item = Scalar> + '_;
|
||||
fn body(&self) -> SwitchedScalar;
|
||||
fn mask(&self) -> impl ExactSizeIterator<Item = SwitchedScalar> + '_;
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_modulus_switch<Scalar, SwitchedScalar, Cont>(
|
||||
|
||||
@@ -11,6 +11,7 @@ use tfhe_versionable::Versionize;
|
||||
pub use super::ciphertext_modulus::CiphertextModulus;
|
||||
use super::traits::CastInto;
|
||||
use crate::core_crypto::backward_compatibility::commons::parameters::*;
|
||||
pub use crate::core_crypto::commons::dispersion::StandardDev;
|
||||
|
||||
/// The number of plaintexts in a plaintext list.
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
|
||||
@@ -479,6 +480,28 @@ impl NormalizedHammingWeightBound {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around an `f64` holding a maximum 2-norm value.
#[derive(Copy, Clone, Debug)]
pub struct MaxNorm2(pub f64);

/// Marker trait implemented by the tag types accepted by [`CleartextModulus`].
pub trait CleartextMarker: Copy {}

/// Tag type selecting the message part of the cleartext space.
#[derive(Copy, Clone, Debug)]
pub struct MessageSpace;

/// Tag type selecting the carry part of the cleartext space.
#[derive(Copy, Clone, Debug)]
pub struct CarrySpace;

impl CleartextMarker for MessageSpace {}

impl CleartextMarker for CarrySpace {}

/// A `u64` modulus tagged at the type level with the cleartext space it applies to, so that
/// moduli of different spaces are distinct types.
#[derive(Copy, Clone, Debug)]
pub struct CleartextModulus<T: CleartextMarker>(pub u64, core::marker::PhantomData<T>);

impl<T: CleartextMarker> CleartextModulus<T> {
    /// Builds a modulus of value `value` for the space selected by `T`.
    pub const fn new(value: u64) -> Self {
        let space_tag = core::marker::PhantomData;
        Self(value, space_tag)
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
@@ -0,0 +1,849 @@
|
||||
//! Experimental module containing implementations of extended bootstrapping algorithms.
|
||||
//!
|
||||
//! See [this paper](https://eprint.iacr.org/2025/2214.pdf).
|
||||
|
||||
use crate::core_crypto::algorithms::glwe_linear_algebra::glwe_ciphertext_sub_assign;
|
||||
use crate::core_crypto::algorithms::glwe_sample_extraction::extract_lwe_sample_from_glwe_ciphertext;
|
||||
use crate::core_crypto::algorithms::modulus_switch::{
|
||||
lwe_ciphertext_modulus_switch, ModulusSwitchedLweCiphertext,
|
||||
};
|
||||
use crate::core_crypto::algorithms::polynomial_algorithms::{
|
||||
polynomial_wrapping_monic_monomial_div, polynomial_wrapping_monic_monomial_mul,
|
||||
};
|
||||
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, GlweSize, MonomialDegree, PolynomialSize,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::{
|
||||
Container, ContainerMut, ContiguousEntityContainer, ContiguousEntityContainerMut,
|
||||
};
|
||||
use crate::core_crypto::entities::glwe_ciphertext::{GlweCiphertext, GlweCiphertextOwned};
|
||||
use crate::core_crypto::entities::lwe_ciphertext::{LweCiphertext, LweCiphertextView};
|
||||
use crate::core_crypto::experimental::commons::parameters::{
|
||||
LweBootstrapExtensionFactor, LweExtendedBootstrapShortcutCoeffCount,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::fft64::c64;
|
||||
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::FourierLweBootstrapKey;
|
||||
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
|
||||
add_external_product_assign, add_external_product_assign_scratch,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::FftView;
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use itertools::izip;
|
||||
|
||||
use std::cell::UnsafeCell;
|
||||
|
||||
/// View over a mutable slice whose elements are exposed as `UnsafeCell`s, so that a shared
/// (`Copy`) handle can hand out references to individual elements.
///
/// Avoiding conflicting accesses to the same index is entirely up to the caller, see
/// [`UnsafeSlice::read`] and [`UnsafeSlice::write`].
#[derive(Copy, Clone)]
pub(crate) struct UnsafeSlice<'a, T> {
    slice: &'a [UnsafeCell<T>],
}

unsafe impl<T: Send + Sync> Send for UnsafeSlice<'_, T> {}
unsafe impl<T: Send + Sync> Sync for UnsafeSlice<'_, T> {}

impl<'a, T> UnsafeSlice<'a, T> {
    /// Wraps `slice`, which stays exclusively borrowed for `'a`.
    pub fn new(slice: &'a mut [T]) -> Self {
        // Reinterpret the `[T]` as a `[UnsafeCell<T>]` of the same length
        let cells = std::ptr::from_mut::<[T]>(slice) as *const [UnsafeCell<T>];
        let slice = unsafe { &*cells };
        Self { slice }
    }

    /// Returns a shared reference to the element at `idx`, panics if `idx` is out of bounds.
    ///
    /// # Safety
    ///
    /// The caller must make sure that no concurrent read and write access occur for a given index
    pub unsafe fn read(&self, idx: usize) -> &T {
        &*self.slice[idx].get()
    }

    /// Returns a mutable reference to the element at `idx`, panics if `idx` is out of bounds.
    ///
    /// # Safety
    ///
    /// The caller must make sure that no concurrent read and write access occur for a given index
    #[allow(clippy::mut_from_ref)] // that's the point of the UnsafeSlice in that case
    pub unsafe fn write(&self, idx: usize) -> &mut T {
        &mut *self.slice[idx].get()
    }
}
|
||||
|
||||
/// Requires all GlweCiphertexts in `split_luts` to have the same GlweSize and PolynomialSize and
|
||||
/// the input `lut` PolynomialSize to be equal to : small_luts PolynomialSize times extension_factor
|
||||
pub fn unchecked_split_extended_lut_into_small_luts<Scalar, ExtendedLutCont, SplitLutCont>(
|
||||
lut: &GlweCiphertext<ExtendedLutCont>,
|
||||
split_luts: &mut [GlweCiphertext<SplitLutCont>],
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
ExtendedLutCont: Container<Element = Scalar>,
|
||||
SplitLutCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
for (idx, &coeff) in lut.as_ref().iter().enumerate() {
|
||||
let dst_lut = &mut split_luts[idx % extension_factor.get()];
|
||||
dst_lut.as_mut()[idx / extension_factor.get()] = coeff;
|
||||
}
|
||||
}
|
||||
|
||||
/// Given an ai mod switched under the extended polynomial size N' = 2^nu * N
|
||||
/// This function gives the monomial multiplication to apply to a small (split) lut to compute the
|
||||
/// end of the rotation
|
||||
///
|
||||
/// N' = 2^nu * N
|
||||
///
|
||||
/// With 2^nu split small LUTs of size N, to compute the full rotation of the extended lut of size
|
||||
/// N' by X^ai one needs to:
|
||||
///
|
||||
/// Move the lut at old_lut_idx to new_lut_idx:
|
||||
/// new_lut_idx = (ai + old_lut_idx) % 2^nu
|
||||
///
|
||||
/// In the lut at new_lut_idx apply a monomial rotation whose exponent is:
|
||||
/// exponent = (2^nu + (ai % 2N') - 1 - new_lut_idx)/2^nu
|
||||
///
|
||||
/// The resul of this function is the above exponent
|
||||
pub(crate) fn small_lut_monomial_degree_from_extended_lut_monomial_degree(
|
||||
extended_lut_monomial_degree: MonomialDegree,
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
// The index of the small lut being rotated
|
||||
small_lut_idx: usize,
|
||||
) -> MonomialDegree {
|
||||
MonomialDegree(
|
||||
(extension_factor.get() + extended_lut_monomial_degree.0 - 1 - small_lut_idx)
|
||||
>> extension_factor.get().ilog2(),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized<
|
||||
Scalar,
|
||||
KeyCont,
|
||||
OutputCont,
|
||||
InputCont,
|
||||
AccCont,
|
||||
>(
|
||||
bsk: &FourierLweBootstrapKey<KeyCont>,
|
||||
lwe_out: &mut LweCiphertext<OutputCont>,
|
||||
lwe_in: &LweCiphertext<InputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
fft: FftView<'_>,
|
||||
stack: &mut PodStack,
|
||||
thread_buffers: &mut [&mut PodStack],
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
KeyCont: Container<Element = c64> + Sync,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
AccCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(
|
||||
lwe_out.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
bsk.polynomial_size().0 * extension_factor.get(),
|
||||
accumulator.polynomial_size().0
|
||||
);
|
||||
|
||||
// For the extended bootstrap the mod switch is done using the extended polynomial size
|
||||
let msed = lwe_ciphertext_modulus_switch(
|
||||
lwe_in.as_view(),
|
||||
accumulator
|
||||
.polynomial_size()
|
||||
.to_blind_rotation_input_modulus_log(),
|
||||
);
|
||||
|
||||
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_impl(
|
||||
bsk,
|
||||
extension_factor,
|
||||
accumulator,
|
||||
&msed,
|
||||
lwe_out.as_mut_view(),
|
||||
fft,
|
||||
stack,
|
||||
thread_buffers,
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_impl<
|
||||
Scalar,
|
||||
KeyCont,
|
||||
LutCont,
|
||||
MsedLwe,
|
||||
>(
|
||||
bsk: &FourierLweBootstrapKey<KeyCont>,
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
input_lut: &GlweCiphertext<LutCont>,
|
||||
msed_lwe: &MsedLwe,
|
||||
mut lwe_out: LweCiphertext<&mut [Scalar]>,
|
||||
fft: FftView<'_>,
|
||||
stack: &mut PodStack,
|
||||
thread_stacks: &mut [&mut PodStack],
|
||||
) where
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
KeyCont: Container<Element = c64> + Sync,
|
||||
LutCont: Container<Element = Scalar>,
|
||||
MsedLwe: ModulusSwitchedLweCiphertext<usize> + Sync,
|
||||
{
|
||||
assert_eq!(thread_stacks.len(), extension_factor.get());
|
||||
|
||||
let lut_poly_size = input_lut.polynomial_size();
|
||||
let ciphertext_modulus = input_lut.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
assert_eq!(
|
||||
bsk.polynomial_size().0 * extension_factor.get(),
|
||||
lut_poly_size.0
|
||||
);
|
||||
assert_eq!(
|
||||
msed_lwe.log_modulus(),
|
||||
lut_poly_size.to_blind_rotation_input_modulus_log()
|
||||
);
|
||||
|
||||
let monomial_degree = MonomialDegree(msed_lwe.body());
|
||||
|
||||
let (lut_data, stack) = stack.make_aligned_raw(input_lut.as_ref().len(), CACHELINE_ALIGN);
|
||||
|
||||
let mut lut = GlweCiphertext::from_container(
|
||||
&mut *lut_data,
|
||||
input_lut.polynomial_size(),
|
||||
input_lut.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
lut.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.zip(input_lut.as_polynomial_list().iter())
|
||||
.for_each(|(mut dst_poly, src_poly)| {
|
||||
polynomial_wrapping_monic_monomial_div(&mut dst_poly, &src_poly, monomial_degree)
|
||||
});
|
||||
|
||||
// Remove mutability to make things more readable and use 0/1 notation to refer to computation
|
||||
// buffers
|
||||
let ct0 = lut;
|
||||
|
||||
let mut split_ct0 = Vec::with_capacity(extension_factor.get());
|
||||
let mut split_ct1 = Vec::with_capacity(extension_factor.get());
|
||||
|
||||
let substack0 = {
|
||||
let mut current_stack = stack;
|
||||
for _ in 0..extension_factor.get() {
|
||||
let (glwe_cont, substack) = current_stack.make_aligned_raw::<Scalar>(
|
||||
bsk.glwe_size().0 * bsk.polynomial_size().0,
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
split_ct0.push(GlweCiphertext::from_container(
|
||||
glwe_cont,
|
||||
bsk.polynomial_size(),
|
||||
ciphertext_modulus,
|
||||
));
|
||||
current_stack = substack;
|
||||
}
|
||||
current_stack
|
||||
};
|
||||
|
||||
let _substack1 = {
|
||||
let mut current_stack = substack0;
|
||||
for _ in 0..extension_factor.get() {
|
||||
let (glwe_cont, substack) = current_stack.make_aligned_raw::<Scalar>(
|
||||
bsk.glwe_size().0 * bsk.polynomial_size().0,
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
split_ct1.push(GlweCiphertext::from_container(
|
||||
glwe_cont,
|
||||
bsk.polynomial_size(),
|
||||
ciphertext_modulus,
|
||||
));
|
||||
current_stack = substack;
|
||||
}
|
||||
current_stack
|
||||
};
|
||||
|
||||
unchecked_split_extended_lut_into_small_luts(&ct0, &mut split_ct0, extension_factor);
|
||||
|
||||
let thread_split_ct0 = UnsafeSlice::new(&mut split_ct0);
|
||||
let thread_split_ct1 = UnsafeSlice::new(&mut split_ct1);
|
||||
|
||||
use std::sync::Barrier;
|
||||
let barrier = Barrier::new(extension_factor.get());
|
||||
|
||||
std::thread::scope(|s| {
|
||||
let thread_processing = |id: usize, stack: &mut PodStack| {
|
||||
// ===== Setup thread local resources =====
|
||||
let ct_dst_idx = id;
|
||||
let extension_factor_rem_mask = extension_factor.get() - 1;
|
||||
|
||||
let (diff_dyn_array, stack) = stack.make_aligned_raw::<Scalar>(
|
||||
bsk.glwe_size().0 * bsk.polynomial_size().0,
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
|
||||
let mut diff_buffer = GlweCiphertext::from_container(
|
||||
diff_dyn_array,
|
||||
bsk.polynomial_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
// ===== Perform the subpart of the rotation the thread is responsible for =====
|
||||
for (mask_idx, (monomial_degree, ggsw)) in msed_lwe
|
||||
.mask()
|
||||
.map(MonomialDegree)
|
||||
.zip(bsk.as_view().into_ggsw_iter())
|
||||
.enumerate()
|
||||
{
|
||||
// Update the LUT the current thread looks at simulating the rotation in the
|
||||
// extended LUT
|
||||
//
|
||||
// One thread is responsible for a destination LUT index at all times (alternating
|
||||
// between two work buffers), we simulate the rotation by seeing which "source" LUT
|
||||
// would end up in the index the thread is responsible for, this avoids memory
|
||||
// copies
|
||||
//
|
||||
// The index mapping combined with a well chosen monomial multiplication simulates
|
||||
// the rotation in the extended LUT from rotations of the small split LUTs
|
||||
let ct_src_idx =
|
||||
(ct_dst_idx.wrapping_sub(monomial_degree.0)) & extension_factor_rem_mask;
|
||||
|
||||
// The well chosen monomial degree allowing to complete the simulated rotation in
|
||||
// the extended LUT, see function comment for the formula
|
||||
let small_monomial_degree =
|
||||
small_lut_monomial_degree_from_extended_lut_monomial_degree(
|
||||
monomial_degree,
|
||||
extension_factor,
|
||||
ct_dst_idx,
|
||||
);
|
||||
|
||||
let rotated_buffer = {
|
||||
let (src_to_rotate, dst_rotated, src_unrotated) = if (mask_idx % 2) == 0 {
|
||||
unsafe {
|
||||
(
|
||||
thread_split_ct0.read(ct_src_idx),
|
||||
thread_split_ct1.write(ct_dst_idx),
|
||||
thread_split_ct0.read(ct_dst_idx),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
unsafe {
|
||||
(
|
||||
thread_split_ct1.read(ct_src_idx),
|
||||
thread_split_ct0.write(ct_dst_idx),
|
||||
thread_split_ct1.read(ct_dst_idx),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Prepare the destination for the ext prod by copying the unrotated
|
||||
// accumulator there
|
||||
dst_rotated.as_mut().copy_from_slice(src_unrotated.as_ref());
|
||||
|
||||
for (mut diff_poly, src_to_rotate_poly) in izip!(
|
||||
diff_buffer.as_mut_polynomial_list().iter_mut(),
|
||||
src_to_rotate.as_polynomial_list().iter(),
|
||||
) {
|
||||
// Rotate the lut that ends up in our slot and add to the
|
||||
// destination
|
||||
// This is computing Rot(ACCj)
|
||||
polynomial_wrapping_monic_monomial_mul(
|
||||
&mut diff_poly,
|
||||
&src_to_rotate_poly,
|
||||
small_monomial_degree,
|
||||
);
|
||||
}
|
||||
|
||||
// This is computing Rot(ACCj) - ACCj
|
||||
glwe_ciphertext_sub_assign(&mut diff_buffer, src_unrotated);
|
||||
|
||||
dst_rotated
|
||||
};
|
||||
|
||||
// ACCj ← BSKi x (Rot(ACCj) - ACCj) + ACCj
|
||||
add_external_product_assign(
|
||||
rotated_buffer.as_mut_view(),
|
||||
ggsw,
|
||||
diff_buffer.as_view(),
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
|
||||
let _ = barrier.wait();
|
||||
}
|
||||
};
|
||||
|
||||
#[allow(clippy::needless_collect)]
|
||||
let threads: Vec<_> = thread_stacks
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.map(|(id, stack)| s.spawn(move || thread_processing(id, stack)))
|
||||
.collect();
|
||||
|
||||
for t in threads {
|
||||
t.join().unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let lwe_dimension = bsk.input_lwe_dimension().0;
|
||||
let buffer_to_use = if lwe_dimension.is_multiple_of(2) {
|
||||
split_ct0
|
||||
} else {
|
||||
split_ct1
|
||||
};
|
||||
|
||||
let mut lut_0 = buffer_to_use.into_iter().next().unwrap();
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
// fill the MSBs and leave the LSBs empty, this usage of the signed decomposer allows to
|
||||
// round while keeping the data in the MSBs
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
lut_0
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
|
||||
}
|
||||
|
||||
let split_accumulator = GlweCiphertext::from_container(
|
||||
lut_0.as_ref().to_vec(),
|
||||
lut_0.polynomial_size(),
|
||||
lut_0.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&split_accumulator, &mut lwe_out, MonomialDegree(0));
|
||||
}
|
||||
|
||||
/// Return the required memory for
|
||||
/// [`extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized`].
|
||||
pub fn extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement<
|
||||
OutputScalar,
|
||||
>(
|
||||
glwe_size: GlweSize,
|
||||
small_polynomial_size: PolynomialSize,
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
fft: FftView<'_>,
|
||||
) -> StackReq {
|
||||
let local_accumulator_req = StackReq::new_aligned::<OutputScalar>(
|
||||
glwe_size.0 * small_polynomial_size.0 * extension_factor.get(),
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
|
||||
// split ct0 allocation
|
||||
// we need (k + 1) * N * 2^nu
|
||||
let split_ct0_req = StackReq::new_aligned::<OutputScalar>(
|
||||
glwe_size.0 * small_polynomial_size.0 * extension_factor.get(),
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
|
||||
// split ct1 allocation
|
||||
// we need (k + 1) * N * 2^nu
|
||||
let split_ct1_req = StackReq::new_aligned::<OutputScalar>(
|
||||
glwe_size.0 * small_polynomial_size.0 * extension_factor.get(),
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
|
||||
StackReq::all_of(&[
|
||||
local_accumulator_req,
|
||||
split_ct0_req,
|
||||
split_ct1_req,
|
||||
// diff_buffer allocation
|
||||
// we need (k + 1) * N
|
||||
StackReq::new_aligned::<OutputScalar>(
|
||||
glwe_size.0 * small_polynomial_size.0,
|
||||
CACHELINE_ALIGN,
|
||||
),
|
||||
// external product
|
||||
add_external_product_assign_scratch::<OutputScalar>(glwe_size, small_polynomial_size, fft),
|
||||
])
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn sorted_extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized<
|
||||
Scalar,
|
||||
KeyCont,
|
||||
OutputCont,
|
||||
InputCont,
|
||||
AccCont,
|
||||
>(
|
||||
bsk: &FourierLweBootstrapKey<KeyCont>,
|
||||
lwe_out: &mut LweCiphertext<OutputCont>,
|
||||
lwe_in: &LweCiphertext<InputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
shortcut_coeff_count: LweExtendedBootstrapShortcutCoeffCount,
|
||||
fft: FftView<'_>,
|
||||
stack: &mut PodStack,
|
||||
thread_buffers: &mut [&mut PodStack],
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
KeyCont: Container<Element = c64> + Sync,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
AccCont: Container<Element = Scalar>,
|
||||
{
|
||||
// extension factor == 1 means classic bootstrap which is already optimized
|
||||
if extension_factor.0 == 1 {
|
||||
return bsk.as_view().bootstrap(
|
||||
lwe_out.as_mut_view(),
|
||||
lwe_in.as_view(),
|
||||
accumulator.as_view(),
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
lwe_out.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
bsk.polynomial_size().0 * extension_factor.0,
|
||||
accumulator.polynomial_size().0
|
||||
);
|
||||
|
||||
// TODO ? use only a split accumulator and an assign primitive ?
|
||||
let split_accumulator = sorted_extended_blind_rotate_mem_optimized_parallelized(
|
||||
bsk,
|
||||
accumulator,
|
||||
lwe_in.as_view(),
|
||||
extension_factor,
|
||||
shortcut_coeff_count,
|
||||
fft,
|
||||
stack,
|
||||
thread_buffers,
|
||||
);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&split_accumulator, lwe_out, MonomialDegree(0));
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn sorted_extended_blind_rotate_mem_optimized_parallelized<Scalar, KeyCont, LutCont>(
|
||||
bsk: &FourierLweBootstrapKey<KeyCont>,
|
||||
input_lut: &GlweCiphertext<LutCont>,
|
||||
input_lwe: LweCiphertextView<'_, Scalar>,
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
shortcut_coeff_count: LweExtendedBootstrapShortcutCoeffCount,
|
||||
fft: FftView<'_>,
|
||||
stack: &mut PodStack,
|
||||
thread_stacks: &mut [&mut PodStack],
|
||||
) -> GlweCiphertextOwned<Scalar>
|
||||
where
|
||||
Scalar: UnsignedTorus + CastInto<usize>,
|
||||
KeyCont: Container<Element = c64> + Sync,
|
||||
LutCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(thread_stacks.len(), extension_factor.0);
|
||||
|
||||
let (lwe_mask, lwe_body) = input_lwe.get_mask_and_body();
|
||||
|
||||
let lut_poly_size = input_lut.polynomial_size();
|
||||
let br_input_modulus_log = lut_poly_size.to_blind_rotation_input_modulus_log();
|
||||
let ciphertext_modulus = input_lut.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
assert_eq!(
|
||||
bsk.polynomial_size().0 * extension_factor.0,
|
||||
lut_poly_size.0
|
||||
);
|
||||
let monomial_degree = MonomialDegree(modulus_switch(
|
||||
(*lwe_body.data).cast_into(),
|
||||
// This one should be the extended polynomial size
|
||||
br_input_modulus_log,
|
||||
));
|
||||
|
||||
let (lut_data, stack) = stack.make_aligned_raw(input_lut.as_ref().len(), CACHELINE_ALIGN);
|
||||
|
||||
let mut lut = GlweCiphertext::from_container(
|
||||
&mut *lut_data,
|
||||
input_lut.polynomial_size(),
|
||||
input_lut.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
lut.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.zip(input_lut.as_polynomial_list().iter())
|
||||
.for_each(|(mut dst_poly, src_poly)| {
|
||||
polynomial_wrapping_monic_monomial_div(&mut dst_poly, &src_poly, monomial_degree)
|
||||
});
|
||||
|
||||
// Remove mutability to make things more readable and use 0/1 notation to refer to computation
|
||||
// buffers
|
||||
let ct0 = lut;
|
||||
|
||||
let mut split_ct0 = Vec::with_capacity(extension_factor.0);
|
||||
let mut split_ct1 = Vec::with_capacity(extension_factor.0);
|
||||
|
||||
let substack0 = {
|
||||
let mut current_stack = stack;
|
||||
for _ in 0..extension_factor.0 {
|
||||
let (glwe_cont, substack) = current_stack.make_aligned_raw::<Scalar>(
|
||||
bsk.glwe_size().0 * bsk.polynomial_size().0,
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
split_ct0.push(GlweCiphertext::from_container(
|
||||
glwe_cont,
|
||||
bsk.polynomial_size(),
|
||||
ct0.ciphertext_modulus(),
|
||||
));
|
||||
current_stack = substack;
|
||||
}
|
||||
current_stack
|
||||
};
|
||||
|
||||
let _substack1 = {
|
||||
let mut current_stack = substack0;
|
||||
for _ in 0..extension_factor.0 {
|
||||
let (glwe_cont, substack) = current_stack.make_aligned_raw::<Scalar>(
|
||||
bsk.glwe_size().0 * bsk.polynomial_size().0,
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
split_ct1.push(GlweCiphertext::from_container(
|
||||
glwe_cont,
|
||||
bsk.polynomial_size(),
|
||||
ct0.ciphertext_modulus(),
|
||||
));
|
||||
current_stack = substack;
|
||||
}
|
||||
current_stack
|
||||
};
|
||||
|
||||
// Split the LUT into small LUTs
|
||||
for (idx, coeff) in ct0.as_ref().iter().copied().enumerate() {
|
||||
let dst_lut = &mut split_ct0[idx % extension_factor.0];
|
||||
dst_lut.as_mut()[idx / extension_factor.0] = coeff;
|
||||
}
|
||||
|
||||
let thread_split_ct0 = UnsafeSlice::new(&mut split_ct0);
|
||||
let thread_split_ct1 = UnsafeSlice::new(&mut split_ct1);
|
||||
|
||||
use std::sync::Barrier;
|
||||
let extension_factor_log2 = extension_factor.0.ilog2();
|
||||
let extension_factor_rem_mask = extension_factor.0 - 1;
|
||||
let congruence_classes_count = extension_factor_log2 as usize + 1;
|
||||
let mut congruence_classes: Vec<_> = (0..congruence_classes_count)
|
||||
.map(|idx| (vec![], Barrier::new(extension_factor.0 / (1 << idx))))
|
||||
.collect();
|
||||
|
||||
let mut shortcut_destinations = vec![vec![]; congruence_classes_count];
|
||||
|
||||
'outer: for (mask_idx, &mask_element) in lwe_mask.as_ref().iter().enumerate() {
|
||||
let mod_switched: usize = modulus_switch(mask_element.cast_into(), br_input_modulus_log);
|
||||
|
||||
if mod_switched % 2 == 1 {
|
||||
let modulus_switch_log = (lut_poly_size.0 * 2).ilog2() as usize;
|
||||
let rounding_bit = (mask_element >> (Scalar::BITS - modulus_switch_log)) & Scalar::ONE;
|
||||
let altered_mod_switch = if rounding_bit == Scalar::ZERO {
|
||||
mod_switched.wrapping_add(1) % (1 << modulus_switch_log)
|
||||
} else {
|
||||
mod_switched.wrapping_sub(1) % (1 << modulus_switch_log)
|
||||
};
|
||||
// for mod_idx in 1..congruence_classes.len() - 1
|
||||
for (mod_idx, shortcut_dest) in shortcut_destinations
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.take(congruence_classes.len() - 1)
|
||||
.skip(1)
|
||||
{
|
||||
let mod_power = mod_idx + 1;
|
||||
let modulus: usize = (Scalar::ONE << mod_power).cast_into();
|
||||
let expected_remainder = modulus >> 1;
|
||||
|
||||
if altered_mod_switch % modulus == expected_remainder {
|
||||
shortcut_dest.push((mask_idx, (mod_switched, altered_mod_switch)));
|
||||
continue 'outer;
|
||||
}
|
||||
}
|
||||
shortcut_destinations[congruence_classes_count - 1]
|
||||
.push((mask_idx, (mod_switched, altered_mod_switch)));
|
||||
continue;
|
||||
}
|
||||
|
||||
// println!();
|
||||
// println!("{mod_switched:064b}");
|
||||
|
||||
for mod_idx in 1..congruence_classes.len() - 1 {
|
||||
let mod_power = mod_idx + 1;
|
||||
let modulus: usize = (Scalar::ONE << mod_power).cast_into();
|
||||
// println!("modulus={modulus}");
|
||||
let expected_remainder = modulus >> 1;
|
||||
// println!("expected_remainder={expected_remainder}");
|
||||
|
||||
if mod_switched % modulus == expected_remainder {
|
||||
// println!("In class {expected_remainder}");
|
||||
congruence_classes[mod_idx].0.push((mask_idx, mod_switched));
|
||||
continue 'outer;
|
||||
}
|
||||
}
|
||||
// println!("In other class");
|
||||
congruence_classes[congruence_classes_count - 1]
|
||||
.0
|
||||
.push((mask_idx, mod_switched));
|
||||
}
|
||||
|
||||
let mut shortcut_remaining = shortcut_coeff_count.0;
|
||||
|
||||
for (shortcut_class_idx, shortcut_class) in shortcut_destinations.iter().enumerate().rev() {
|
||||
for (mask_idx, (mod_switched, altered_mod_switch)) in shortcut_class.iter().copied() {
|
||||
if shortcut_remaining > 0 {
|
||||
shortcut_remaining -= 1;
|
||||
congruence_classes[shortcut_class_idx]
|
||||
.0
|
||||
.push((mask_idx, altered_mod_switch));
|
||||
} else {
|
||||
congruence_classes[0].0.push((mask_idx, mod_switched));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let gathered_dim = congruence_classes.iter().map(|x| x.0.len()).sum::<usize>();
|
||||
assert_eq!(gathered_dim, lwe_mask.as_ref().len());
|
||||
|
||||
std::thread::scope(|s| {
|
||||
let thread_processing = |id: usize, stack: &mut PodStack| {
|
||||
let (diff_dyn_array, stack) = stack.make_aligned_raw::<Scalar>(
|
||||
bsk.glwe_size().0 * bsk.polynomial_size().0,
|
||||
CACHELINE_ALIGN,
|
||||
);
|
||||
|
||||
let mut diff_buffer = GlweCiphertext::from_container(
|
||||
diff_dyn_array,
|
||||
bsk.polynomial_size(),
|
||||
ct0.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let ggsw_vec = bsk.as_view().into_ggsw_iter().collect::<Vec<_>>();
|
||||
|
||||
// let mut skipped = 0;
|
||||
|
||||
let mut overall_loop_idx = 0;
|
||||
|
||||
for (congruence_class_idx, (mask_indices, barrier)) in
|
||||
congruence_classes.iter().enumerate()
|
||||
{
|
||||
let ct_dst_idx = id;
|
||||
let should_process = ct_dst_idx.is_multiple_of(1 << congruence_class_idx);
|
||||
if !should_process {
|
||||
return;
|
||||
}
|
||||
//todo!("get correct barrier to wait");
|
||||
for (mask_idx, monomial_degree) in mask_indices.iter().copied() {
|
||||
// println!("mask_element: {mask_element:064b}");
|
||||
let ggsw = ggsw_vec[mask_idx];
|
||||
let monomial_degree = MonomialDegree(monomial_degree);
|
||||
|
||||
//todo use id of thread
|
||||
// Update the lut we look at simulating the rotation in the larger lut
|
||||
let ct_src_idx =
|
||||
(ct_dst_idx.wrapping_sub(monomial_degree.0)) & extension_factor_rem_mask;
|
||||
// Compute the end of the rotation
|
||||
// N' = 2^nu * N
|
||||
// new_lut_idx = (ai + old_lut_idx) % 2^nu
|
||||
// (2^nu + (ai % 2N') - 1 - new_lut_idx)/2^nu a l'air de marcher pour x X^ai
|
||||
// monomial degree = mod switch(ai) already % 2N'
|
||||
let small_monomial_degree = MonomialDegree(
|
||||
(extension_factor.0 + monomial_degree.0 - 1 - ct_dst_idx)
|
||||
>> extension_factor_log2,
|
||||
);
|
||||
let rotated_buffer = {
|
||||
let (src_to_rotate, dst_rotated, src_unrotated) =
|
||||
if (overall_loop_idx % 2) == 0 {
|
||||
unsafe {
|
||||
(
|
||||
thread_split_ct0.read(ct_src_idx),
|
||||
&mut *thread_split_ct1.write(ct_dst_idx),
|
||||
thread_split_ct0.read(ct_dst_idx),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
unsafe {
|
||||
(
|
||||
thread_split_ct1.read(ct_src_idx),
|
||||
&mut *thread_split_ct0.write(ct_dst_idx),
|
||||
thread_split_ct1.read(ct_dst_idx),
|
||||
)
|
||||
}
|
||||
};
|
||||
// Prepare the destination for the ext prod by copying the unrotated
|
||||
// accumulator there
|
||||
dst_rotated.as_mut().copy_from_slice(src_unrotated.as_ref());
|
||||
for (mut diff_poly, src_to_rotate_poly) in izip!(
|
||||
diff_buffer.as_mut_polynomial_list().iter_mut(),
|
||||
src_to_rotate.as_polynomial_list().iter(),
|
||||
) {
|
||||
// Rotate the lut that ends up in our slot and add to the
|
||||
// destination This is computing
|
||||
// Rot(ACCj)
|
||||
polynomial_wrapping_monic_monomial_mul(
|
||||
&mut diff_poly,
|
||||
&src_to_rotate_poly,
|
||||
small_monomial_degree,
|
||||
);
|
||||
}
|
||||
|
||||
// This is computing Rot(ACCj) - ACCj
|
||||
glwe_ciphertext_sub_assign(&mut diff_buffer, src_unrotated);
|
||||
|
||||
dst_rotated
|
||||
};
|
||||
|
||||
// ACCj ← BSKi x (Rot(ACCj) - ACCj) + ACCj
|
||||
add_external_product_assign(
|
||||
rotated_buffer.as_mut_view(),
|
||||
ggsw,
|
||||
diff_buffer.as_view(),
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
let _ = barrier.wait();
|
||||
overall_loop_idx += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#[allow(clippy::needless_collect)]
|
||||
let threads: Vec<_> = thread_stacks
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.map(|(id, stack)| s.spawn(move || thread_processing(id, stack)))
|
||||
.collect();
|
||||
|
||||
for t in threads {
|
||||
t.join().unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let lwe_dimension = bsk.input_lwe_dimension().0;
|
||||
let buffer_to_use = if lwe_dimension.is_multiple_of(2) {
|
||||
split_ct0
|
||||
} else {
|
||||
split_ct1
|
||||
};
|
||||
|
||||
let mut lut_0 = buffer_to_use.into_iter().next().unwrap();
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
// fill the MSBs and leave the LSBs empty, this usage of the signed decomposer allows to
|
||||
// round while keeping the data in the MSBs
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
lut_0
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
|
||||
}
|
||||
|
||||
GlweCiphertext::from_container(
|
||||
lut_0.as_ref().to_vec(),
|
||||
lut_0.polynomial_size(),
|
||||
lut_0.ciphertext_modulus(),
|
||||
)
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod glwe_fast_keyswitch;
|
||||
pub mod glwe_partial_sample_extraction;
|
||||
pub mod lwe_extended_programmable_bootstrapping;
|
||||
pub mod lwe_shrinking_keyswitch;
|
||||
pub mod lwe_shrinking_keyswitch_key_generation;
|
||||
pub mod partial_glwe_secret_key_generation;
|
||||
@@ -10,6 +11,7 @@ pub mod shared_lwe_secret_key_generation;
|
||||
|
||||
pub use glwe_fast_keyswitch::*;
|
||||
pub use glwe_partial_sample_extraction::*;
|
||||
pub use lwe_extended_programmable_bootstrapping::*;
|
||||
pub use lwe_shrinking_keyswitch::*;
|
||||
pub use lwe_shrinking_keyswitch_key_generation::*;
|
||||
pub use partial_glwe_secret_key_generation::*;
|
||||
|
||||
@@ -0,0 +1,499 @@
|
||||
use super::*;
|
||||
use crate::core_crypto::algorithms::glwe_secret_key_generation::allocate_and_generate_new_binary_glwe_secret_key;
|
||||
use crate::core_crypto::algorithms::lwe_bootstrap_key_conversion::par_convert_standard_lwe_bootstrap_key_to_fourier;
|
||||
use crate::core_crypto::algorithms::lwe_bootstrap_key_generation::par_allocate_and_generate_new_lwe_bootstrap_key;
|
||||
use crate::core_crypto::algorithms::lwe_encryption::{
|
||||
allocate_and_encrypt_new_lwe_ciphertext, decrypt_lwe_ciphertext,
|
||||
};
|
||||
use crate::core_crypto::algorithms::lwe_programmable_bootstrapping::generate_programmable_bootstrap_glwe_lut;
|
||||
use crate::core_crypto::algorithms::lwe_secret_key_generation::allocate_and_generate_new_binary_lwe_secret_key;
|
||||
use crate::core_crypto::algorithms::misc::check_encrypted_content_respects_mod;
|
||||
use crate::core_crypto::algorithms::polynomial_algorithms;
|
||||
use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
CarrySpace, CiphertextModulus, CleartextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
DynamicDistribution, EncryptionKeyChoice, GlweDimension, GlweSize, LweDimension, MaxNorm2,
|
||||
MessageSpace, MonomialDegree, PolynomialSize, StandardDev,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::{CastInto, ContiguousEntityContainerMut};
|
||||
use crate::core_crypto::entities::glwe_ciphertext::GlweCiphertext;
|
||||
use crate::core_crypto::entities::lwe_ciphertext::LweCiphertext;
|
||||
use crate::core_crypto::entities::plaintext::Plaintext;
|
||||
use crate::core_crypto::experimental::algorithms::lwe_extended_programmable_bootstrapping::{
|
||||
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized,
|
||||
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement,
|
||||
small_lut_monomial_degree_from_extended_lut_monomial_degree,
|
||||
unchecked_split_extended_lut_into_small_luts,
|
||||
};
|
||||
use crate::core_crypto::experimental::commons::parameters::LweBootstrapExtensionFactor;
|
||||
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::FourierLweBootstrapKey;
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
|
||||
|
||||
// This test checks that the rotation in the extended polynomial ring is correctly performed by
// first moving small polynomials around and then doing a monomial mul in each small polynomial by a
// well chosen X^index
// N' = 2^nu * N
// new_lut_idx = (ai + old_lut_idx) % 2^nu
// small_lut_monomial_index = (2^nu + (ai % 2N') - 1 - new_lut_idx)/2^nu
// looks to work to multiply by X^ai
#[test]
fn test_monic_mul_split_eq() {
    use rand::Rng;

    let mut rng = rand::thread_rng();

    let glwe_size = GlweSize(2);
    let ciphertext_modulus = CiphertextModulus::<u64>::new_native();

    for _ in 0..100 {
        let small_poly_size = PolynomialSize(1usize << rng.gen_range(5..=12));
        let extension_factor = LweBootstrapExtensionFactor::new(1 << rng.gen_range(1..=3));
        let polynomial_size = PolynomialSize(small_poly_size.0 * extension_factor.get());

        let mut extended_lut =
            GlweCiphertext::new(0u64, glwe_size, polynomial_size, ciphertext_modulus);
        extended_lut
            .as_mut()
            .iter_mut()
            .for_each(|x| *x = rng.gen());

        // Remove mutability, this is the reference extended LUT from now on
        let extended_lut = extended_lut;

        let mut small_luts =
            vec![
                GlweCiphertext::new(0u64, glwe_size, small_poly_size, ciphertext_modulus);
                extension_factor.get()
            ];

        unchecked_split_extended_lut_into_small_luts(
            &extended_lut,
            &mut small_luts,
            extension_factor,
        );
        let ref_small_luts = small_luts;

        for _ in 0..1000 {
            // Modulo 2*N' to mimic the modulus switch for the extended N' = 2^nu * N polynomial
            // size
            let monomial_degree = MonomialDegree(rng.gen::<usize>() % (polynomial_size.0 * 2));

            // Compute the reference for the rotation results as split luts: rotate the
            // extended LUT first, then split the rotated result
            let ref_small_rotated_luts = {
                let mut ref_rotated_lut = extended_lut.clone();
                for mut polynomial_to_rotate in ref_rotated_lut.as_mut_polynomial_list().iter_mut()
                {
                    polynomial_algorithms::polynomial_wrapping_monic_monomial_mul_assign(
                        &mut polynomial_to_rotate,
                        monomial_degree,
                    );
                }

                let mut ref_small_rotated_luts =
                    vec![
                        GlweCiphertext::new(0u64, glwe_size, small_poly_size, ciphertext_modulus);
                        extension_factor.get()
                    ];

                unchecked_split_extended_lut_into_small_luts(
                    &ref_rotated_lut,
                    &mut ref_small_rotated_luts,
                    extension_factor,
                );

                ref_small_rotated_luts
            };

            // Compute the rotation results with the trick
            let small_luts_rotated_with_trick = {
                // Copy the unmodified reference extended LUT as small luts
                let mut small_luts_rotated_with_trick = ref_small_luts.clone();

                // Rotate the lookup tables by the right amount to start simulating the extended LUT
                // rotation
                small_luts_rotated_with_trick
                    .rotate_right(monomial_degree.0 % extension_factor.get());

                // Complete the rotation with the monomial degree "trick" for small LUTs, using a
                // well chosen monomial degree for each small LUT rotation
                for (lut_idx, small_lut) in small_luts_rotated_with_trick.iter_mut().enumerate() {
                    let small_monomial_degree =
                        small_lut_monomial_degree_from_extended_lut_monomial_degree(
                            monomial_degree,
                            extension_factor,
                            lut_idx,
                        );
                    for mut polynomial_to_rotate in small_lut.as_mut_polynomial_list().iter_mut() {
                        polynomial_algorithms::polynomial_wrapping_monic_monomial_mul_assign(
                            &mut polynomial_to_rotate,
                            small_monomial_degree,
                        )
                    }
                }
                small_luts_rotated_with_trick
            };

            // Verify our formulas work the way we expect by comparing to the reference not using
            // the trick
            assert_eq!(ref_small_rotated_luts, small_luts_rotated_with_trick);
        }
    }
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct ExtendedPBSTestParameters {
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
extension_factor: LweBootstrapExtensionFactor,
|
||||
lwe_noise_distribution: DynamicDistribution<u64>,
|
||||
glwe_noise_distribution: DynamicDistribution<u64>,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
ks_level: DecompositionLevelCount,
|
||||
message_modulus: CleartextModulus<MessageSpace>,
|
||||
carry_modulus: CleartextModulus<CarrySpace>,
|
||||
max_norm2: MaxNorm2,
|
||||
log2_p_fail: f64,
|
||||
ciphertext_modulus: CiphertextModulus<u64>,
|
||||
encryption_key_choice: EncryptionKeyChoice,
|
||||
}
|
||||
|
||||
// p-fail = 2^-128.147, algorithmic cost ~ 67456140, 2-norm = 5, extension factor = 16,
|
||||
pub const TEST_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128: ExtendedPBSTestParameters =
|
||||
ExtendedPBSTestParameters {
|
||||
lwe_dimension: LweDimension(884),
|
||||
glwe_dimension: GlweDimension(4),
|
||||
polynomial_size: PolynomialSize(512),
|
||||
extension_factor: LweBootstrapExtensionFactor::new(16),
|
||||
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
|
||||
1.4999005934396873e-06,
|
||||
)),
|
||||
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
|
||||
2.845267479601915e-15,
|
||||
)),
|
||||
pbs_base_log: DecompositionBaseLog(23),
|
||||
pbs_level: DecompositionLevelCount(1),
|
||||
ks_base_log: DecompositionBaseLog(5),
|
||||
ks_level: DecompositionLevelCount(3),
|
||||
message_modulus: CleartextModulus::new(4),
|
||||
carry_modulus: CleartextModulus::new(4),
|
||||
max_norm2: MaxNorm2(5f64),
|
||||
log2_p_fail: -128.0,
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
encryption_key_choice: EncryptionKeyChoice::Big,
|
||||
};
|
||||
|
||||
fn lwe_encrypt_extended_pbs_decrypt(params: ExtendedPBSTestParameters) {
|
||||
let lwe_dimension = params.lwe_dimension;
|
||||
let lwe_noise_distribution = params.lwe_noise_distribution;
|
||||
let message_modulus = params.message_modulus.0;
|
||||
let carry_modulus = params.carry_modulus.0;
|
||||
let plaintext_modulus = message_modulus * carry_modulus;
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let extension_factor = params.extension_factor;
|
||||
let base_polynomial_size = params.polynomial_size;
|
||||
let extended_polynomial_size = PolynomialSize(base_polynomial_size.0 * extension_factor.get());
|
||||
let glwe_noise_distribution = params.glwe_noise_distribution;
|
||||
let pbs_base_log = params.pbs_base_log;
|
||||
let pbs_level_count = params.pbs_level;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
let f = |x: u64| x;
|
||||
|
||||
let delta: u64 = encoding_with_padding / plaintext_modulus;
|
||||
let mut msg = plaintext_modulus;
|
||||
|
||||
let accumulator = generate_programmable_bootstrap_glwe_lut(
|
||||
extended_polynomial_size,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
plaintext_modulus.cast_into(),
|
||||
ciphertext_modulus,
|
||||
delta,
|
||||
f,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&accumulator,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
lwe_dimension,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
base_polynomial_size,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let output_lwe_secret_key = output_glwe_secret_key.as_lwe_secret_key();
|
||||
|
||||
let bsk = par_allocate_and_generate_new_lwe_bootstrap_key(
|
||||
&input_lwe_secret_key,
|
||||
&output_glwe_secret_key,
|
||||
pbs_base_log,
|
||||
pbs_level_count,
|
||||
glwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut fbsk = FourierLweBootstrapKey::new(
|
||||
bsk.input_lwe_dimension(),
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
);
|
||||
|
||||
par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fbsk);
|
||||
|
||||
let fft = Fft::new(base_polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
// TODO: have req for main thread and for workers ?
|
||||
use extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement as rq;
|
||||
|
||||
let requirement = rq::<u64>(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
base_polynomial_size,
|
||||
extension_factor,
|
||||
fft,
|
||||
)
|
||||
.unaligned_bytes_required();
|
||||
|
||||
buffers.resize(requirement);
|
||||
|
||||
let mut thread_buffers = Vec::with_capacity(extension_factor.get());
|
||||
for _ in 0..extension_factor.get() {
|
||||
let mut buffer = ComputationBuffers::new();
|
||||
buffer.resize(requirement);
|
||||
thread_buffers.push(buffer);
|
||||
}
|
||||
|
||||
let mut thread_stacks: Vec<_> = thread_buffers.iter_mut().map(|x| x.stack()).collect();
|
||||
|
||||
while msg != 0 {
|
||||
msg = msg.wrapping_sub(1);
|
||||
|
||||
for _ in 0..10 {
|
||||
let plaintext = Plaintext(msg * delta);
|
||||
|
||||
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&input_lwe_secret_key,
|
||||
plaintext,
|
||||
lwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&lwe_ciphertext_in,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let mut out_pbs_ct = LweCiphertext::new(
|
||||
0,
|
||||
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
|
||||
&fbsk,
|
||||
&mut out_pbs_ct,
|
||||
&lwe_ciphertext_in,
|
||||
&accumulator,
|
||||
extension_factor,
|
||||
fft,
|
||||
buffers.stack(),
|
||||
&mut thread_stacks,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&out_pbs_ct,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let decrypted = decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);
|
||||
|
||||
let decoded = round_decode(decrypted.0, delta) % plaintext_modulus;
|
||||
|
||||
assert_eq!(decoded, f(msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(tarpaulin)]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parameterized_test!(lwe_encrypt_extended_pbs_decrypt {
|
||||
TEST_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128,
|
||||
});
|
||||
|
||||
|
||||
fn lwe_encrypt_sorted_extended_pbs_decrypt(params: ExtendedPBSTestParameters) {
|
||||
let lwe_dimension = params.lwe_dimension;
|
||||
let lwe_noise_distribution = params.lwe_noise_distribution;
|
||||
let message_modulus = params.message_modulus.0;
|
||||
let carry_modulus = params.carry_modulus.0;
|
||||
let plaintext_modulus = message_modulus * carry_modulus;
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let extension_factor = params.extension_factor;
|
||||
let base_polynomial_size = params.polynomial_size;
|
||||
let extended_polynomial_size = PolynomialSize(base_polynomial_size.0 * extension_factor.0);
|
||||
let glwe_noise_distribution = params.glwe_noise_distribution;
|
||||
let pbs_base_log = params.pbs_base_log;
|
||||
let pbs_level_count = params.pbs_level;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
let f = |x: u64| x;
|
||||
|
||||
let delta: u64 = encoding_with_padding / plaintext_modulus;
|
||||
let mut msg = plaintext_modulus;
|
||||
|
||||
let accumulator = generate_programmable_bootstrap_glwe_lut(
|
||||
extended_polynomial_size,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
plaintext_modulus.cast_into(),
|
||||
ciphertext_modulus,
|
||||
delta,
|
||||
f,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&accumulator,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
lwe_dimension,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
base_polynomial_size,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let output_lwe_secret_key = output_glwe_secret_key.as_lwe_secret_key();
|
||||
|
||||
let bsk = par_allocate_and_generate_new_lwe_bootstrap_key(
|
||||
&input_lwe_secret_key,
|
||||
&output_glwe_secret_key,
|
||||
pbs_base_log,
|
||||
pbs_level_count,
|
||||
glwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut fbsk = FourierLweBootstrapKey::new(
|
||||
bsk.input_lwe_dimension(),
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
);
|
||||
|
||||
par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fbsk);
|
||||
|
||||
let fft = Fft::new(base_polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
// TODO: have req for main thread and for workers ?
|
||||
use extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement as rq;
|
||||
|
||||
let requirement = rq::<u64>(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
base_polynomial_size,
|
||||
extension_factor,
|
||||
fft,
|
||||
)
|
||||
.unaligned_bytes_required();
|
||||
|
||||
buffers.resize(requirement);
|
||||
|
||||
let mut thread_buffers = Vec::with_capacity(extension_factor.0);
|
||||
for _ in 0..extension_factor.0 {
|
||||
let mut buffer = ComputationBuffers::new();
|
||||
buffer.resize(requirement);
|
||||
thread_buffers.push(buffer);
|
||||
}
|
||||
|
||||
let mut thread_stacks: Vec<_> = thread_buffers.iter_mut().map(|x| x.stack()).collect();
|
||||
|
||||
while msg != 0 {
|
||||
msg = msg.wrapping_sub(1);
|
||||
|
||||
for _ in 0..10 {
|
||||
let plaintext = Plaintext(msg * delta);
|
||||
|
||||
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&input_lwe_secret_key,
|
||||
plaintext,
|
||||
lwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&lwe_ciphertext_in,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let mut out_pbs_ct = LweCiphertext::new(
|
||||
0,
|
||||
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
sorted_extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
|
||||
&fbsk,
|
||||
&mut out_pbs_ct,
|
||||
&lwe_ciphertext_in,
|
||||
&accumulator,
|
||||
extension_factor,
|
||||
fft,
|
||||
buffers.stack(),
|
||||
&mut thread_stacks,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&out_pbs_ct,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let decrypted = decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);
|
||||
|
||||
let decoded = round_decode(decrypted.0, delta) % plaintext_modulus;
|
||||
|
||||
assert_eq!(decoded, f(msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(tarpaulin)]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parameterized_test!(lwe_encrypt_sorted_extended_pbs_decrypt {
|
||||
TEST_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128,
|
||||
});
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::core_crypto::algorithms::test::*;
|
||||
use crate::core_crypto::experimental::prelude::*;
|
||||
|
||||
mod lwe_extended_programmable_bootstrapping;
|
||||
mod lwe_fast_keyswitch;
|
||||
mod lwe_stair_keyswitch;
|
||||
|
||||
@@ -43,3 +43,39 @@ impl crate::core_crypto::commons::parameters::LweDimension {
|
||||
LweSecretKeyUnsharedCoefCount(self.0 - shared_coef_count.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameter indicating by how much a LUT polynomial size is multiplied in the extended PBS
|
||||
/// setting.
|
||||
///
|
||||
/// The extended PBS simulates the rotation of a larger LUT using only smaller LUTs. This
|
||||
/// has nice noise properties in some cases. The extended LUT polynomial size is N' = tau * N where
|
||||
/// tau is a power of 2, usual notation is tau = 2^nu, so N' = 2^nu * N.
|
||||
///
|
||||
/// See [this paper](https://eprint.iacr.org/2025/2214.pdf).
|
||||
///
|
||||
/// Currently the extension factor needs to be a power of two to keep a compatible power of two for
|
||||
/// the extended LUT.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct LweBootstrapExtensionFactor(usize);
|
||||
|
||||
impl LweBootstrapExtensionFactor {
|
||||
pub const fn new(value: usize) -> Self {
|
||||
assert!(
|
||||
value > 1,
|
||||
"An LweBootstrapExtensionFactor <= 1 makes no sense."
|
||||
);
|
||||
assert!(
|
||||
value.is_power_of_two(),
|
||||
"LweBootstrapExtensionFactor needs to be a power of 2"
|
||||
);
|
||||
|
||||
Self(value)
|
||||
}
|
||||
|
||||
pub const fn get(&self) -> usize {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Newtype around a `usize` coefficient count.
///
/// NOTE(review): judging by the name, this is the number of coefficients handled by a "shortcut"
/// of the extended bootstrap. Confirm against the code that consumes it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LweExtendedBootstrapShortcutCoeffCount(pub usize);
|
||||
|
||||
@@ -1477,4 +1477,67 @@ pub(crate) mod test {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Uniformity check of `generate_oblivious_pseudo_random_bits` over several parameter sets.
///
/// For each parameter set and for 3 and 4 requested random bits, `test_uniformity` is run with
/// `sample_count` samples, the p-value limit below and `1 << random_bits_count` as third
/// argument. Each sample is generated from a seed, its blocks are decrypted, and the block values
/// are recombined into a single `u64`, each block being shifted by the number of bits of the
/// blocks before it.
#[test]
fn oprf_test_uniformity_bits_ci_run_filter() {
    use crate::shortint::gen_keys;
    use crate::shortint::parameters::test_params::{
        TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128,
        TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
    };
    use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

    let sample_count: usize = 100_000;
    let p_value_limit: f64 = 0.000_01;

    let parameter_sets = [
        ShortintParameterSet::from(
            TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
        ),
        ShortintParameterSet::from(PARAM_MESSAGE_2_CARRY_2_KS_PBS),
        ShortintParameterSet::from(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128),
    ];

    for params in parameter_sets {
        let (ck, sk) = gen_keys(params);
        let oprf_ck = OprfPrivateKey::new(&ck);
        let oprf_sk = OprfServerKey::new(&oprf_ck, &ck).unwrap();

        let random_bits_per_block = sk.message_modulus.0.ilog2() as u64;

        for random_bits_count in [3u64, 4] {
            let expected_num_blocks = random_bits_count.div_ceil(random_bits_per_block) as usize;

            test_uniformity(
                sample_count,
                p_value_limit,
                1 << random_bits_count,
                |seed| {
                    let seed = (seed as u128).to_le_bytes();
                    let blocks = oprf_sk.generate_oblivious_pseudo_random_bits(
                        seed.as_slice(),
                        random_bits_count,
                        &sk,
                    );

                    // Accumulate (combined value, bit offset of the next block).
                    let (combined, _next_shift) = blocks.iter().enumerate().fold(
                        (0u64, 0u64),
                        |(combined, shift), (i, block)| {
                            let decrypted = ck.decrypt_message_and_carry(block);
                            let block_bits = bits_in_block(
                                i,
                                expected_num_blocks,
                                random_bits_count,
                                random_bits_per_block,
                            );

                            (combined | (decrypted << shift), shift + block_bits)
                        },
                    );

                    combined
                },
            );
        }
    }
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user