feat: x8 faster wop, better cache usage

This commit is contained in:
rudy
2022-06-15 09:46:34 +02:00
committed by rudy-6-4
parent b628cd64fb
commit b44bd6cbfe

View File

@@ -87,7 +87,7 @@ impl Solution {
// Memoization tables holding precomputed (noise, complexity) results, looked up by
// tuples of cryptographic parameters.
pub struct Tab {
// Looked up with the key (glwe_dim, glwe_poly_size, internal_dim, log2_base, level);
// the value is read as (base noise, PBS complexity).
pbs: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>,
// Keyed by a pair of u64 and holding a pair of f64. No use of this table is visible
// in this view, so the meaning of each component is left undocumented — TODO confirm.
modulus_switching: HashMap<(u64, u64), (f64, f64)>,
// NOTE(review): `key_switching` is declared twice in this view; this looks like the
// before/after pair of a diff rather than real code — confirm which one the file keeps.
// Flat form: one entry per (glwe_dim, glwe_poly_size, internal_dim, log2_base, level),
// holding (keyswitch noise, keyswitch complexity).
key_switching: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>,
// Grouped form: one entry per (glwe_dim, glwe_poly_size, internal_dim), holding every
// keyswitch decomposition candidate together with its (noise, complexity), so a single
// lookup yields the whole list to iterate over.
key_switching: HashMap<(u64, u64, u64), Vec<(KsDecompositionParameters, (f64, Complexity))>>,
// Private packing keyswitch results, looked up with the key
// (glwe_dim, glwe_poly_size, log2_base, level); the value is read as
// (base noise, complexity).
pp_switching: HashMap<(u64, u64, u64, u64), (f64, Complexity)>,
}
@@ -130,6 +130,8 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
for &internal_dim in internal_lwe_dimensions {
assert!(256 < internal_dim);
let macro_key = (glwe_dim, glwe_poly_size, internal_dim);
let variance_ksk = noise_atomic_pattern::variance_ksk(
internal_dim,
ciphertext_modulus_log,
@@ -180,6 +182,7 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
);
}
let mut ks_seq = Vec::with_capacity(KS_BL_FOR_CB.len());
for &ks_decomposition_parameter in KS_BL_FOR_CB.iter() {
let keyswitch_parameter = KeyswitchParameters {
input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim),
@@ -196,17 +199,12 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
variance_ksk,
)
.get_variance();
let _ = noise_cost_key_switching.insert(
(
glwe_dim,
glwe_poly_size,
internal_dim,
ks_decomposition_parameter.log2_base,
ks_decomposition_parameter.level,
),
ks_seq.push((
ks_decomposition_parameter,
(noise_keyswitch, complexity_keyswitch),
);
));
}
std::mem::drop(noise_cost_key_switching.insert(macro_key, ks_seq));
// let pp_ks_decomposition_parameter = pbs_parameters.br_decomposition_parameter;
for &pp_ks_decomposition_parameter in BR_BL.iter() {
@@ -327,6 +325,8 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
internal_dim, glwe_poly_size
));
let macro_key = (glwe_dim, glwe_poly_size, internal_dim);
// BlindRotate in Circuit BS
for &br_decomposition_parameter in BR_BL_FOR_CB.iter() {
// PBS in BitExtract, Circuit BS and FP-KS (shared)
@@ -348,6 +348,42 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
))
.unwrap();
// new pbs key for the bit extract pbs, shared
let bit_extract_decomposition_parameter = br_decomposition_parameter;
let &(_bit_extract_base_noise, complexity_bit_extract_pbs) = memo
.pbs
.get(&(
glwe_dim,
glwe_poly_size,
internal_dim,
bit_extract_decomposition_parameter.log2_base,
bit_extract_decomposition_parameter.level,
))
.unwrap();
// private packing keyswitch, <=> FP-KS (Circuit Bootstrap)
let pp_ks_decomposition_parameter =
pbs_parameters.br_decomposition_parameter;
// for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independent params for FP-KS
// Circuit Bootstrap
let &(base_noise_private_packing_ks, complexity_ppks) = memo
.pp_switching
.get(&(
glwe_dim,
glwe_poly_size,
pp_ks_decomposition_parameter.log2_base,
pp_ks_decomposition_parameter.level,
))
.expect(&format!(
"{}, {}, {}, {}",
glwe_dim,
glwe_poly_size,
pp_ks_decomposition_parameter.log2_base,
pp_ks_decomposition_parameter.level,
));
// CircuitBootstrap: new parameters l,b
for &circuit_pbs_decomposition_parameter in CB_V1_BL.iter() {
// Hybrid packing
@@ -392,27 +428,8 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
(precision * n_inputs) as f64,
);
// private packing keyswitch, <=> FP-KS (Circuit Bootstrap)
let pp_ks_decomposition_parameter =
pbs_parameters.br_decomposition_parameter;
// for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independent params for FP-KS
// Circuit Bootstrap
let &(mut noise_private_packing_ks, complexity_ppks) = memo
.pp_switching
.get(&(
glwe_dim,
glwe_poly_size,
pp_ks_decomposition_parameter.log2_base,
pp_ks_decomposition_parameter.level,
))
.expect(&format!(
"{}, {}, {}, {}",
glwe_dim,
glwe_poly_size,
pp_ks_decomposition_parameter.log2_base,
pp_ks_decomposition_parameter.level,
));
noise_private_packing_ks += base_noise / 2.;
let noise_private_packing_ks =
base_noise_private_packing_ks + base_noise / 2.;
// Circuit Bootstrap
if noise_private_packing_ks + noise_modulus_switching > variance_max
@@ -448,23 +465,17 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
(2_f64.powf(2. * log_norm as f64)) * noise_cmux_tree_blind_rotate; // out noise * weights
// Shared by all pbs (like brs)
for &ks_decomposition_parameter in KS_BL_FOR_CB.iter() {
let key_switching_q = memo.key_switching.get(&macro_key).unwrap();
for &(
ks_decomposition_parameter,
(noise_keyswitch, complexity_keyswitch),
) in key_switching_q.iter()
{
let keyswitch_parameter = KeyswitchParameters {
input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim),
output_lwe_dimension: LweDimension(internal_dim),
ks_decomposition_parameter,
};
let &(noise_keyswitch, complexity_keyswitch) = memo
.key_switching
.get(&(
glwe_dim,
glwe_poly_size,
internal_dim,
ks_decomposition_parameter.log2_base,
ks_decomposition_parameter.level,
))
.unwrap();
let complexity_all_ks = precision as f64 * complexity_keyswitch;
if noise_private_packing_ks
+ noise_modulus_switching
@@ -478,21 +489,6 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
continue;
}
// new pbs key for the bit extract pbs, shared
let bit_extract_decomposition_parameter =
br_decomposition_parameter;
let &(_bit_extract_base_noise, complexity_bit_extract_pbs) = memo
.pbs
.get(&(
glwe_dim,
glwe_poly_size,
internal_dim,
bit_extract_decomposition_parameter.log2_base,
bit_extract_decomposition_parameter.level,
))
.unwrap();
// noise_multisum = dot
let current_maximal_noise = noise_multisum
* (1 << (2 * (precision - 1))) as f64
@@ -514,6 +510,11 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
* (1 << n_functions_log) as f64
+ complexity_bias;
if current_complexity > best_complexity {
// next level is more costly
break;
}
if current_complexity < best_complexity
&& current_maximal_noise < variance_max
{
@@ -558,9 +559,6 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
circuit_pbs_decomposition_parameter.log2_base,
),
});
} else if current_complexity > best_complexity {
// next level is more costly
break;
}
}
}