mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(woppbs): make code more readable, reexpress some cuts (x4 speed)
This commit is contained in:
@@ -86,7 +86,7 @@ impl Solution {
|
||||
#[derive(Debug)]
|
||||
pub struct Tab {
|
||||
pbs: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>,
|
||||
modulus_switching: HashMap<(u64, u64), (f64, f64)>,
|
||||
modulus_switching: HashMap<(u64, u64), f64>,
|
||||
key_switching: HashMap<(u64, u64, u64), Vec<(KsDecompositionParameters, (f64, Complexity))>>,
|
||||
// NEW VALUE MEMOIZED
|
||||
pp_switching: HashMap<(u64, u64, u64, u64), (f64, Complexity)>,
|
||||
@@ -143,10 +143,8 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
|
||||
glwe_params.polynomial_size(),
|
||||
)
|
||||
.get_variance();
|
||||
let _ = noise_cost_modulus_switching.insert(
|
||||
(internal_dim, glwe_poly_size),
|
||||
(noise_modulus_switching, 0.),
|
||||
);
|
||||
let _ = noise_cost_modulus_switching
|
||||
.insert((internal_dim, glwe_poly_size), noise_modulus_switching);
|
||||
|
||||
for &br_decomposition_parameter in BR_BL.iter() {
|
||||
let pbs_parameters = PbsParameters {
|
||||
@@ -206,7 +204,6 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
|
||||
}
|
||||
std::mem::drop(noise_cost_key_switching.insert(macro_key, ks_seq));
|
||||
|
||||
// let pp_ks_decomposition_parameter = pbs_parameters.br_decomposition_parameter;
|
||||
for &pp_ks_decomposition_parameter in BR_BL.iter() {
|
||||
let ppks_parameter = PbsParameters {
|
||||
internal_lwe_dimension: LweDimension(
|
||||
@@ -274,15 +271,15 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
glwe_log_polynomial_sizes: &[u64],
|
||||
glwe_dimensions: &[u64],
|
||||
internal_lwe_dimensions: &[u64],
|
||||
n_functions_log: u64, // Many functions at the same time, stay at 1 for start
|
||||
n_functions: u64, // Many functions at the same time, stay at 1 for start
|
||||
memo: &Tab,
|
||||
n_inputs: u64, // Tau (nb blocks)
|
||||
) -> OptimizationState {
|
||||
assert!(n_functions_log == 0); // update complexity scaling
|
||||
assert!(0.0 < maximum_acceptable_error_probability);
|
||||
assert!(maximum_acceptable_error_probability < 1.0);
|
||||
|
||||
let ciphertext_modulus_log = W::BITS as u64;
|
||||
let global_precision = n_inputs * precision;
|
||||
|
||||
// Circuit BS bound
|
||||
// 1 bit of message only here =)
|
||||
@@ -316,8 +313,10 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
glwe_dimension: glwe_dim,
|
||||
};
|
||||
|
||||
let input_lwe_dimension = glwe_params.lwe_dimension();
|
||||
|
||||
for &internal_dim in internal_lwe_dimensions {
|
||||
let &(noise_modulus_switching, _) = memo
|
||||
let &noise_modulus_switching = memo
|
||||
.modulus_switching
|
||||
.get(&(internal_dim, glwe_poly_size))
|
||||
.expect(&format!(
|
||||
@@ -325,6 +324,10 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
internal_dim, glwe_poly_size
|
||||
));
|
||||
|
||||
if noise_modulus_switching > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let macro_key = (glwe_dim, glwe_poly_size, internal_dim);
|
||||
|
||||
// BlindRotate dans Circuit BS
|
||||
@@ -362,10 +365,16 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let complexity_bit_extract_wo_ks =
|
||||
(n_inputs * (precision - 1)) as f64 * complexity_bit_extract_pbs;
|
||||
|
||||
if complexity_bit_extract_wo_ks > best_complexity {
|
||||
continue;
|
||||
}
|
||||
|
||||
// private packing keyswitch, <=> FP-KS (Circuit Boostrap)
|
||||
let pp_ks_decomposition_parameter =
|
||||
pbs_parameters.br_decomposition_parameter;
|
||||
// for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independant params for FP-KS
|
||||
|
||||
// Circuit Boostrap
|
||||
let &(base_noise_private_packing_ks, complexity_ppks) = memo
|
||||
@@ -394,52 +403,63 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
output_glwe_params: pbs_parameters.output_glwe_params,
|
||||
};
|
||||
|
||||
// Hybrid packing (rename)
|
||||
let complexity_cmux_for_cb = DEFAULT_COMPLEXITY.pbs.complexity(
|
||||
// Hybrid packing
|
||||
let complexity_1_cmux_hp = DEFAULT_COMPLEXITY.pbs.complexity(
|
||||
cmux_tree_blind_rotate_parameters,
|
||||
ciphertext_modulus_log,
|
||||
); // TODO: missing fft transform
|
||||
|
||||
// Hybrid packing (Do we have 1 or 2 groups)
|
||||
#[allow(clippy::precedence)]
|
||||
let complexity_cmux_tree = if precision * n_inputs as u64 // sum of precisions
|
||||
> pbs_parameters.output_glwe_params.log2_polynomial_size
|
||||
{
|
||||
// 2 groups
|
||||
complexity_cmux_for_cb
|
||||
* (1 << (precision * n_inputs
|
||||
- pbs_parameters.output_glwe_params.log2_polynomial_size)
|
||||
- 1) as f64
|
||||
// * (f64::exp2(
|
||||
// (precision * n_inputs
|
||||
// - pbs_parameters
|
||||
// .output_glwe_params
|
||||
// .log2_polynomial_size)
|
||||
// as f64,
|
||||
// ) - 1.)
|
||||
let log2_polynomial_size =
|
||||
pbs_parameters.output_glwe_params.log2_polynomial_size;
|
||||
// Size of cmux_group, can be zero
|
||||
let cmux_group_count = if global_precision > log2_polynomial_size {
|
||||
2f64.powi((global_precision - log2_polynomial_size - 1) as i32)
|
||||
} else {
|
||||
// 1 group, no cmux tree
|
||||
0.
|
||||
0.0
|
||||
};
|
||||
let complexity_cmux_tree =
|
||||
cmux_group_count as f64 * complexity_1_cmux_hp;
|
||||
// Hybrid packing blind rotate
|
||||
let complexity_g_br = complexity_cmux_for_cb
|
||||
* f64::min(
|
||||
(pbs_parameters.output_glwe_params.log2_polynomial_size) as f64,
|
||||
(precision * n_inputs) as f64,
|
||||
);
|
||||
let complexity_g_br = complexity_1_cmux_hp
|
||||
* u64::min(
|
||||
pbs_parameters.output_glwe_params.log2_polynomial_size,
|
||||
global_precision,
|
||||
) as f64;
|
||||
|
||||
let noise_private_packing_ks =
|
||||
base_noise_private_packing_ks + base_noise / 2.;
|
||||
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
|
||||
let complexity_multi_hybrid_packing =
|
||||
n_functions as f64 * complexity_hybrid_packing;
|
||||
|
||||
// Circuit Boostrap
|
||||
if noise_private_packing_ks + noise_modulus_switching > variance_max
|
||||
|| (precision - 1) as f64 * complexity_pbs + complexity_ppks
|
||||
> best_complexity
|
||||
// Circuit bs: fp-ks
|
||||
let complexity_all_ppks =
|
||||
((pbs_parameters.output_glwe_params.glwe_dimension + 1)
|
||||
* circuit_pbs_decomposition_parameter.level
|
||||
* precision
|
||||
* n_inputs) as f64
|
||||
* complexity_ppks;
|
||||
|
||||
// Circuit bs: pbs
|
||||
let complexity_all_pbs =
|
||||
(n_inputs * precision * circuit_pbs_decomposition_parameter.level)
|
||||
as f64
|
||||
* complexity_pbs;
|
||||
|
||||
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
|
||||
|
||||
if complexity_bit_extract_wo_ks + complexity_circuit_bs
|
||||
> best_complexity
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_ggsw = noise_private_packing_ks;
|
||||
let noise_ggsw = base_noise_private_packing_ks + base_noise / 2.;
|
||||
|
||||
// Circuit Boostrap
|
||||
let noise_hybrid_packing = noise_modulus_switching + noise_ggsw;
|
||||
if noise_hybrid_packing > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_one_external_product_for_cmux_tree =
|
||||
noise_atomic_pattern::variance_bootstrap::<W>(
|
||||
@@ -449,13 +469,6 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
)
|
||||
.get_variance();
|
||||
|
||||
// all fp-ks
|
||||
let complexity_all_ppks =
|
||||
((pbs_parameters.output_glwe_params.glwe_dimension + 1)
|
||||
* circuit_pbs_decomposition_parameter.level
|
||||
* precision) as f64
|
||||
* complexity_ppks;
|
||||
|
||||
// final out noise hybrid packing
|
||||
let noise_cmux_tree_blind_rotate =
|
||||
noise_one_external_product_for_cmux_tree
|
||||
@@ -464,94 +477,64 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
let noise_multisum =
|
||||
(2_f64.powf(2. * log_norm as f64)) * noise_cmux_tree_blind_rotate; // out noise * weights
|
||||
|
||||
let noise_all_multisum =
|
||||
noise_multisum * (1 << (2 * (precision - 1))) as f64;
|
||||
|
||||
let noise_ggsw_reencoding =
|
||||
noise_modulus_switching + noise_all_multisum;
|
||||
if noise_ggsw_reencoding > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_max = noise_ggsw_reencoding.max(noise_hybrid_packing);
|
||||
|
||||
// Shared by all pbs (like brs)
|
||||
let key_switching_q = memo.key_switching.get(¯o_key).unwrap();
|
||||
for &(
|
||||
ks_decomposition_parameter,
|
||||
(noise_keyswitch, complexity_keyswitch),
|
||||
) in key_switching_q.iter()
|
||||
) in key_switching_q
|
||||
{
|
||||
let keyswitch_parameter = KeyswitchParameters {
|
||||
input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim),
|
||||
output_lwe_dimension: LweDimension(internal_dim),
|
||||
ks_decomposition_parameter,
|
||||
};
|
||||
let complexity_all_ks = precision as f64 * complexity_keyswitch;
|
||||
if noise_private_packing_ks
|
||||
+ noise_modulus_switching
|
||||
+ noise_keyswitch
|
||||
> variance_max
|
||||
|| (precision - 1) as f64 * complexity_pbs
|
||||
+ complexity_ppks
|
||||
+ precision as f64 * complexity_keyswitch
|
||||
> best_complexity
|
||||
{
|
||||
let noise_max = noise_max + noise_keyswitch;
|
||||
if noise_max > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
// noise_multisum = dot
|
||||
let current_maximal_noise = noise_multisum
|
||||
* (1 << (2 * (precision - 1))) as f64
|
||||
+ noise_keyswitch
|
||||
+ noise_modulus_switching;
|
||||
let complexity_all_ks =
|
||||
(precision * n_inputs) as f64 * complexity_keyswitch;
|
||||
let complexity_bit_extract =
|
||||
complexity_bit_extract_wo_ks + complexity_all_ks;
|
||||
|
||||
let complexity_all_pbs =
|
||||
(precision * circuit_pbs_decomposition_parameter.level) as f64
|
||||
* complexity_pbs
|
||||
+ (precision - 1) as f64 * complexity_bit_extract_pbs;
|
||||
let complexity_ggsw_reencoding =
|
||||
complexity_bit_extract + complexity_circuit_bs;
|
||||
|
||||
let complexity_bias =
|
||||
(complexity_all_ppks + complexity_all_pbs + complexity_all_ks)
|
||||
* n_inputs as f64;
|
||||
let complexity =
|
||||
complexity_ggsw_reencoding + complexity_multi_hybrid_packing;
|
||||
|
||||
let complexity_slope = complexity_cmux_tree + complexity_g_br;
|
||||
|
||||
let current_complexity = complexity_slope
|
||||
* (1 << n_functions_log) as f64
|
||||
+ complexity_bias;
|
||||
|
||||
if current_complexity > best_complexity {
|
||||
// next level is more costly
|
||||
if complexity > best_complexity {
|
||||
// next ks.level will be even more costly
|
||||
break;
|
||||
}
|
||||
|
||||
if current_complexity < best_complexity
|
||||
&& current_maximal_noise < variance_max
|
||||
{
|
||||
best_complexity = current_complexity;
|
||||
if complexity < best_complexity {
|
||||
best_complexity = complexity;
|
||||
let p_error = find_p_error(kappa, variance_max, noise_max);
|
||||
state.best_solution = Some(Solution {
|
||||
input_lwe_dimension: pbs_parameters
|
||||
.output_glwe_params
|
||||
.glwe_dimension
|
||||
* pbs_parameters.output_glwe_params.polynomial_size(),
|
||||
internal_ks_output_lwe_dimension: keyswitch_parameter
|
||||
.output_lwe_dimension
|
||||
.0,
|
||||
ks_decomposition_level_count: keyswitch_parameter
|
||||
.ks_decomposition_parameter
|
||||
input_lwe_dimension,
|
||||
internal_ks_output_lwe_dimension: internal_dim,
|
||||
ks_decomposition_level_count: ks_decomposition_parameter
|
||||
.level,
|
||||
ks_decomposition_base_log: keyswitch_parameter
|
||||
.ks_decomposition_parameter
|
||||
ks_decomposition_base_log: ks_decomposition_parameter
|
||||
.log2_base,
|
||||
glwe_polynomial_size: pbs_parameters
|
||||
.output_glwe_params
|
||||
.polynomial_size(),
|
||||
glwe_dimension: pbs_parameters
|
||||
.output_glwe_params
|
||||
.glwe_dimension,
|
||||
br_decomposition_level_count: pbs_parameters
|
||||
.br_decomposition_parameter
|
||||
glwe_polynomial_size: glwe_poly_size,
|
||||
glwe_dimension: glwe_dim,
|
||||
br_decomposition_level_count: br_decomposition_parameter
|
||||
.level,
|
||||
br_decomposition_base_log: pbs_parameters
|
||||
.br_decomposition_parameter
|
||||
br_decomposition_base_log: br_decomposition_parameter
|
||||
.log2_base,
|
||||
noise_max: current_maximal_noise,
|
||||
complexity: current_complexity,
|
||||
p_error: find_p_error(
|
||||
kappa,
|
||||
variance_max,
|
||||
current_maximal_noise,
|
||||
), // consts.maximum_acceptable_error_probability,
|
||||
noise_max,
|
||||
complexity,
|
||||
p_error,
|
||||
cb_decomposition_level_count: Some(
|
||||
circuit_pbs_decomposition_parameter.level,
|
||||
),
|
||||
@@ -571,6 +554,30 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
state
|
||||
}
|
||||
|
||||
// Default heuristic to split in several word
|
||||
pub fn default_partitionning(precision: u64) -> Vec<u64> {
|
||||
#[allow(clippy::match_same_arms)]
|
||||
match precision {
|
||||
1 => vec![1],
|
||||
2 => vec![2],
|
||||
3 => vec![2; 2],
|
||||
4 => vec![3; 2],
|
||||
5 => vec![3; 2],
|
||||
6 => vec![3; 3],
|
||||
7 => vec![3; 3],
|
||||
8 => vec![3; 3],
|
||||
9 => vec![4; 3],
|
||||
10 => vec![4; 3],
|
||||
11 => vec![4; 3],
|
||||
12 => vec![4; 4],
|
||||
13 => vec![4; 4],
|
||||
14 => vec![4; 4],
|
||||
15 => vec![4; 4],
|
||||
16 => vec![5; 4],
|
||||
_ => vec![5; (precision / 5) as usize],
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn optimize_one<W: UnsignedInteger>(
|
||||
_sum_size: u64,
|
||||
@@ -583,33 +590,11 @@ pub fn optimize_one<W: UnsignedInteger>(
|
||||
internal_lwe_dimensions: &[u64],
|
||||
memo_opt: &mut Option<Tab>,
|
||||
) -> atomic_pattern::OptimizationState {
|
||||
// Basic heuristic to split in several word
|
||||
let no_sol = atomic_pattern::OptimizationState {
|
||||
best_solution: None,
|
||||
count_domain: 0,
|
||||
};
|
||||
#[allow(clippy::match_same_arms)]
|
||||
let (nb_words, max_word_precision) = match precision {
|
||||
1 => (1, 1),
|
||||
2 => (1, 2),
|
||||
3 => (2, 2),
|
||||
4 => (2, 3),
|
||||
5 => (2, 3),
|
||||
6 => (3, 3),
|
||||
7 => (3, 3),
|
||||
8 => (3, 3),
|
||||
9 => (3, 4),
|
||||
10 => (3, 4),
|
||||
11 => (3, 4),
|
||||
12 => (4, 4),
|
||||
13 => (4, 4),
|
||||
14 => (4, 4),
|
||||
15 => (4, 4),
|
||||
16 => (4, 5),
|
||||
_ => return no_sol,
|
||||
};
|
||||
let partitionning = default_partitionning(precision);
|
||||
let nb_words = partitionning.len() as u64;
|
||||
let max_word_precision = *partitionning.iter().max().unwrap() as u64;
|
||||
let log_norm = noise_factor.log2();
|
||||
let n_functions_log = 0;
|
||||
let n_functions = 1;
|
||||
let memo = memo_opt.get_or_insert_with(|| {
|
||||
tabulate_circuit_bootstrap::<W>(
|
||||
security_level,
|
||||
@@ -627,7 +612,7 @@ pub fn optimize_one<W: UnsignedInteger>(
|
||||
glwe_log_polynomial_sizes,
|
||||
glwe_dimensions,
|
||||
internal_lwe_dimensions,
|
||||
n_functions_log,
|
||||
n_functions,
|
||||
memo,
|
||||
nb_words, // Tau
|
||||
);
|
||||
|
||||
@@ -26,6 +26,9 @@ mod individual {
|
||||
pub fn polynomial_size(self) -> u64 {
|
||||
1 << self.log2_polynomial_size
|
||||
}
|
||||
pub fn lwe_dimension(self) -> u64 {
|
||||
self.glwe_dimension << self.log2_polynomial_size
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
|
||||
Reference in New Issue
Block a user