feat(woppbs): make code more readable, reexpress some cuts (x4 speed)

This commit is contained in:
rudy
2022-06-16 08:59:11 +02:00
committed by rudy-6-4
parent 91a4e6eea3
commit 5d5f7a4016
2 changed files with 132 additions and 144 deletions

View File

@@ -86,7 +86,7 @@ impl Solution {
#[derive(Debug)]
pub struct Tab {
pbs: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>,
modulus_switching: HashMap<(u64, u64), (f64, f64)>,
modulus_switching: HashMap<(u64, u64), f64>,
key_switching: HashMap<(u64, u64, u64), Vec<(KsDecompositionParameters, (f64, Complexity))>>,
// NEW VALUE MEMOIZED
pp_switching: HashMap<(u64, u64, u64, u64), (f64, Complexity)>,
@@ -143,10 +143,8 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
glwe_params.polynomial_size(),
)
.get_variance();
let _ = noise_cost_modulus_switching.insert(
(internal_dim, glwe_poly_size),
(noise_modulus_switching, 0.),
);
let _ = noise_cost_modulus_switching
.insert((internal_dim, glwe_poly_size), noise_modulus_switching);
for &br_decomposition_parameter in BR_BL.iter() {
let pbs_parameters = PbsParameters {
@@ -206,7 +204,6 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
}
std::mem::drop(noise_cost_key_switching.insert(macro_key, ks_seq));
// let pp_ks_decomposition_parameter = pbs_parameters.br_decomposition_parameter;
for &pp_ks_decomposition_parameter in BR_BL.iter() {
let ppks_parameter = PbsParameters {
internal_lwe_dimension: LweDimension(
@@ -274,15 +271,15 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
glwe_log_polynomial_sizes: &[u64],
glwe_dimensions: &[u64],
internal_lwe_dimensions: &[u64],
n_functions_log: u64, // Many functions at the same time, stay at 1 for start
n_functions: u64, // Many functions at the same time, stay at 1 for start
memo: &Tab,
n_inputs: u64, // Tau (nb blocks)
) -> OptimizationState {
assert!(n_functions_log == 0); // update complexity scaling
assert!(0.0 < maximum_acceptable_error_probability);
assert!(maximum_acceptable_error_probability < 1.0);
let ciphertext_modulus_log = W::BITS as u64;
let global_precision = n_inputs * precision;
// Circuit BS bound
// 1 bit of message only here =)
@@ -316,8 +313,10 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
glwe_dimension: glwe_dim,
};
let input_lwe_dimension = glwe_params.lwe_dimension();
for &internal_dim in internal_lwe_dimensions {
let &(noise_modulus_switching, _) = memo
let &noise_modulus_switching = memo
.modulus_switching
.get(&(internal_dim, glwe_poly_size))
.expect(&format!(
@@ -325,6 +324,10 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
internal_dim, glwe_poly_size
));
if noise_modulus_switching > variance_max {
continue;
}
let macro_key = (glwe_dim, glwe_poly_size, internal_dim);
// BlindRotate dans Circuit BS
@@ -362,10 +365,16 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
))
.unwrap();
let complexity_bit_extract_wo_ks =
(n_inputs * (precision - 1)) as f64 * complexity_bit_extract_pbs;
if complexity_bit_extract_wo_ks > best_complexity {
continue;
}
// private packing keyswitch, <=> FP-KS (Circuit Boostrap)
let pp_ks_decomposition_parameter =
pbs_parameters.br_decomposition_parameter;
// for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independant params for FP-KS
// Circuit Boostrap
let &(base_noise_private_packing_ks, complexity_ppks) = memo
@@ -394,52 +403,63 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
output_glwe_params: pbs_parameters.output_glwe_params,
};
// Hybrid packing (rename)
let complexity_cmux_for_cb = DEFAULT_COMPLEXITY.pbs.complexity(
// Hybrid packing
let complexity_1_cmux_hp = DEFAULT_COMPLEXITY.pbs.complexity(
cmux_tree_blind_rotate_parameters,
ciphertext_modulus_log,
); // TODO: missing fft transform
// Hybrid packing (Do we have 1 or 2 groups)
#[allow(clippy::precedence)]
let complexity_cmux_tree = if precision * n_inputs as u64 // sum of precisions
> pbs_parameters.output_glwe_params.log2_polynomial_size
{
// 2 groups
complexity_cmux_for_cb
* (1 << (precision * n_inputs
- pbs_parameters.output_glwe_params.log2_polynomial_size)
- 1) as f64
// * (f64::exp2(
// (precision * n_inputs
// - pbs_parameters
// .output_glwe_params
// .log2_polynomial_size)
// as f64,
// ) - 1.)
let log2_polynomial_size =
pbs_parameters.output_glwe_params.log2_polynomial_size;
// Size of cmux_group, can be zero
let cmux_group_count = if global_precision > log2_polynomial_size {
2f64.powi((global_precision - log2_polynomial_size - 1) as i32)
} else {
// 1 group, no cmux tree
0.
0.0
};
let complexity_cmux_tree =
cmux_group_count as f64 * complexity_1_cmux_hp;
// Hybrid packing blind rotate
let complexity_g_br = complexity_cmux_for_cb
* f64::min(
(pbs_parameters.output_glwe_params.log2_polynomial_size) as f64,
(precision * n_inputs) as f64,
);
let complexity_g_br = complexity_1_cmux_hp
* u64::min(
pbs_parameters.output_glwe_params.log2_polynomial_size,
global_precision,
) as f64;
let noise_private_packing_ks =
base_noise_private_packing_ks + base_noise / 2.;
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
let complexity_multi_hybrid_packing =
n_functions as f64 * complexity_hybrid_packing;
// Circuit Boostrap
if noise_private_packing_ks + noise_modulus_switching > variance_max
|| (precision - 1) as f64 * complexity_pbs + complexity_ppks
> best_complexity
// Circuit bs: fp-ks
let complexity_all_ppks =
((pbs_parameters.output_glwe_params.glwe_dimension + 1)
* circuit_pbs_decomposition_parameter.level
* precision
* n_inputs) as f64
* complexity_ppks;
// Circuit bs: pbs
let complexity_all_pbs =
(n_inputs * precision * circuit_pbs_decomposition_parameter.level)
as f64
* complexity_pbs;
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
if complexity_bit_extract_wo_ks + complexity_circuit_bs
> best_complexity
{
continue;
}
let noise_ggsw = noise_private_packing_ks;
let noise_ggsw = base_noise_private_packing_ks + base_noise / 2.;
// Circuit Boostrap
let noise_hybrid_packing = noise_modulus_switching + noise_ggsw;
if noise_hybrid_packing > variance_max {
continue;
}
let noise_one_external_product_for_cmux_tree =
noise_atomic_pattern::variance_bootstrap::<W>(
@@ -449,13 +469,6 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
)
.get_variance();
// all fp-ks
let complexity_all_ppks =
((pbs_parameters.output_glwe_params.glwe_dimension + 1)
* circuit_pbs_decomposition_parameter.level
* precision) as f64
* complexity_ppks;
// final out noise hybrid packing
let noise_cmux_tree_blind_rotate =
noise_one_external_product_for_cmux_tree
@@ -464,94 +477,64 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
let noise_multisum =
(2_f64.powf(2. * log_norm as f64)) * noise_cmux_tree_blind_rotate; // out noise * weights
let noise_all_multisum =
noise_multisum * (1 << (2 * (precision - 1))) as f64;
let noise_ggsw_reencoding =
noise_modulus_switching + noise_all_multisum;
if noise_ggsw_reencoding > variance_max {
continue;
}
let noise_max = noise_ggsw_reencoding.max(noise_hybrid_packing);
// Shared by all pbs (like brs)
let key_switching_q = memo.key_switching.get(&macro_key).unwrap();
for &(
ks_decomposition_parameter,
(noise_keyswitch, complexity_keyswitch),
) in key_switching_q.iter()
) in key_switching_q
{
let keyswitch_parameter = KeyswitchParameters {
input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim),
output_lwe_dimension: LweDimension(internal_dim),
ks_decomposition_parameter,
};
let complexity_all_ks = precision as f64 * complexity_keyswitch;
if noise_private_packing_ks
+ noise_modulus_switching
+ noise_keyswitch
> variance_max
|| (precision - 1) as f64 * complexity_pbs
+ complexity_ppks
+ precision as f64 * complexity_keyswitch
> best_complexity
{
let noise_max = noise_max + noise_keyswitch;
if noise_max > variance_max {
continue;
}
// noise_multisum = dot
let current_maximal_noise = noise_multisum
* (1 << (2 * (precision - 1))) as f64
+ noise_keyswitch
+ noise_modulus_switching;
let complexity_all_ks =
(precision * n_inputs) as f64 * complexity_keyswitch;
let complexity_bit_extract =
complexity_bit_extract_wo_ks + complexity_all_ks;
let complexity_all_pbs =
(precision * circuit_pbs_decomposition_parameter.level) as f64
* complexity_pbs
+ (precision - 1) as f64 * complexity_bit_extract_pbs;
let complexity_ggsw_reencoding =
complexity_bit_extract + complexity_circuit_bs;
let complexity_bias =
(complexity_all_ppks + complexity_all_pbs + complexity_all_ks)
* n_inputs as f64;
let complexity =
complexity_ggsw_reencoding + complexity_multi_hybrid_packing;
let complexity_slope = complexity_cmux_tree + complexity_g_br;
let current_complexity = complexity_slope
* (1 << n_functions_log) as f64
+ complexity_bias;
if current_complexity > best_complexity {
// next level is more costly
if complexity > best_complexity {
// next ks.level will be even more costly
break;
}
if current_complexity < best_complexity
&& current_maximal_noise < variance_max
{
best_complexity = current_complexity;
if complexity < best_complexity {
best_complexity = complexity;
let p_error = find_p_error(kappa, variance_max, noise_max);
state.best_solution = Some(Solution {
input_lwe_dimension: pbs_parameters
.output_glwe_params
.glwe_dimension
* pbs_parameters.output_glwe_params.polynomial_size(),
internal_ks_output_lwe_dimension: keyswitch_parameter
.output_lwe_dimension
.0,
ks_decomposition_level_count: keyswitch_parameter
.ks_decomposition_parameter
input_lwe_dimension,
internal_ks_output_lwe_dimension: internal_dim,
ks_decomposition_level_count: ks_decomposition_parameter
.level,
ks_decomposition_base_log: keyswitch_parameter
.ks_decomposition_parameter
ks_decomposition_base_log: ks_decomposition_parameter
.log2_base,
glwe_polynomial_size: pbs_parameters
.output_glwe_params
.polynomial_size(),
glwe_dimension: pbs_parameters
.output_glwe_params
.glwe_dimension,
br_decomposition_level_count: pbs_parameters
.br_decomposition_parameter
glwe_polynomial_size: glwe_poly_size,
glwe_dimension: glwe_dim,
br_decomposition_level_count: br_decomposition_parameter
.level,
br_decomposition_base_log: pbs_parameters
.br_decomposition_parameter
br_decomposition_base_log: br_decomposition_parameter
.log2_base,
noise_max: current_maximal_noise,
complexity: current_complexity,
p_error: find_p_error(
kappa,
variance_max,
current_maximal_noise,
), // consts.maximum_acceptable_error_probability,
noise_max,
complexity,
p_error,
cb_decomposition_level_count: Some(
circuit_pbs_decomposition_parameter.level,
),
@@ -571,6 +554,30 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
state
}
// Default heuristic to split in several word
pub fn default_partitionning(precision: u64) -> Vec<u64> {
#[allow(clippy::match_same_arms)]
match precision {
1 => vec![1],
2 => vec![2],
3 => vec![2; 2],
4 => vec![3; 2],
5 => vec![3; 2],
6 => vec![3; 3],
7 => vec![3; 3],
8 => vec![3; 3],
9 => vec![4; 3],
10 => vec![4; 3],
11 => vec![4; 3],
12 => vec![4; 4],
13 => vec![4; 4],
14 => vec![4; 4],
15 => vec![4; 4],
16 => vec![5; 4],
_ => vec![5; (precision / 5) as usize],
}
}
#[allow(clippy::too_many_lines)]
pub fn optimize_one<W: UnsignedInteger>(
_sum_size: u64,
@@ -583,33 +590,11 @@ pub fn optimize_one<W: UnsignedInteger>(
internal_lwe_dimensions: &[u64],
memo_opt: &mut Option<Tab>,
) -> atomic_pattern::OptimizationState {
// Basic heuristic to split in several word
let no_sol = atomic_pattern::OptimizationState {
best_solution: None,
count_domain: 0,
};
#[allow(clippy::match_same_arms)]
let (nb_words, max_word_precision) = match precision {
1 => (1, 1),
2 => (1, 2),
3 => (2, 2),
4 => (2, 3),
5 => (2, 3),
6 => (3, 3),
7 => (3, 3),
8 => (3, 3),
9 => (3, 4),
10 => (3, 4),
11 => (3, 4),
12 => (4, 4),
13 => (4, 4),
14 => (4, 4),
15 => (4, 4),
16 => (4, 5),
_ => return no_sol,
};
let partitionning = default_partitionning(precision);
let nb_words = partitionning.len() as u64;
let max_word_precision = *partitionning.iter().max().unwrap() as u64;
let log_norm = noise_factor.log2();
let n_functions_log = 0;
let n_functions = 1;
let memo = memo_opt.get_or_insert_with(|| {
tabulate_circuit_bootstrap::<W>(
security_level,
@@ -627,7 +612,7 @@ pub fn optimize_one<W: UnsignedInteger>(
glwe_log_polynomial_sizes,
glwe_dimensions,
internal_lwe_dimensions,
n_functions_log,
n_functions,
memo,
nb_words, // Tau
);

View File

@@ -26,6 +26,9 @@ mod individual {
pub fn polynomial_size(self) -> u64 {
1 << self.log2_polynomial_size
}
pub fn lwe_dimension(self) -> u64 {
self.glwe_dimension << self.log2_polynomial_size
}
}
#[derive(Clone, Copy)]