mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
chore(woppbs): cut optimize_one in two macro/micro
This is a first step towards using pareto_blind_rotate and pareto_keyswitch.
This commit is contained in:
@@ -10,6 +10,7 @@ use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
|
||||
use crate::noise_estimator::operators::wop_atomic_pattern::estimate_packing_private_keyswitch;
|
||||
|
||||
use crate::optimization::atomic_pattern;
|
||||
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
|
||||
use crate::optimization::wop_atomic_pattern::pareto::{
|
||||
BR_BL, BR_BL_FOR_CB, CB_V1_BL, KS_BL, KS_BL_FOR_CB,
|
||||
};
|
||||
@@ -35,7 +36,7 @@ pub struct OptimizationState {
|
||||
pub count_domain: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Solution {
|
||||
pub input_lwe_dimension: u64,
|
||||
//n_big
|
||||
@@ -193,14 +194,212 @@ fn compute_noise_cost_by_micro_param<W: UnsignedInteger>(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn update_state_with_best_decompositions<W: UnsignedInteger>(
|
||||
state: &mut OptimizationState,
|
||||
consts: &OptimizationDecompositionsConsts,
|
||||
glwe_params: GlweParameters,
|
||||
internal_dim: u64,
|
||||
n_functions: u64,
|
||||
precision: u64,
|
||||
n_inputs: u64, // Tau
|
||||
) {
|
||||
let ciphertext_modulus_log = consts.ciphertext_modulus_log;
|
||||
let global_precision = n_inputs * precision;
|
||||
let variance_max = consts.safe_variance;
|
||||
let log_norm = consts.noise_factor.log2();
|
||||
|
||||
let micro_tab =
|
||||
compute_noise_cost_by_micro_param::<W>(consts.security_level, glwe_params, internal_dim);
|
||||
|
||||
let noise_modulus_switching =
|
||||
noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key::<W>(
|
||||
internal_dim,
|
||||
glwe_params.polynomial_size(),
|
||||
)
|
||||
.get_variance();
|
||||
|
||||
if noise_modulus_switching > consts.safe_variance {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut best_complexity = state.best_solution.map_or(f64::INFINITY, |s| s.complexity);
|
||||
|
||||
// BlindRotate dans Circuit BS
|
||||
for (br_dp_index, &br_decomposition_parameter) in BR_BL_FOR_CB.iter().enumerate() {
|
||||
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
|
||||
// TODO: choisir indépendemment(separate FP-KS)
|
||||
let pbs_parameters = PbsParameters {
|
||||
internal_lwe_dimension: LweDimension(internal_dim),
|
||||
br_decomposition_parameter,
|
||||
output_glwe_params: glwe_params,
|
||||
};
|
||||
|
||||
let (base_noise, complexity_pbs) = micro_tab.pbs[br_dp_index];
|
||||
|
||||
// new pbs key for the bit extract pbs, shared
|
||||
let bit_extract_dp_index = br_dp_index;
|
||||
|
||||
let (_bit_extract_base_noise, complexity_bit_extract_pbs) =
|
||||
micro_tab.pbs[bit_extract_dp_index];
|
||||
|
||||
let complexity_bit_extract_wo_ks =
|
||||
(n_inputs * (precision - 1)) as f64 * complexity_bit_extract_pbs;
|
||||
|
||||
if complexity_bit_extract_wo_ks > best_complexity {
|
||||
continue;
|
||||
}
|
||||
|
||||
// private packing keyswitch, <=> FP-KS (Circuit Boostrap)
|
||||
let pp_ks_dp_index = br_dp_index;
|
||||
|
||||
// Circuit Boostrap
|
||||
let (base_noise_private_packing_ks, complexity_ppks) =
|
||||
micro_tab.pp_switching[pp_ks_dp_index];
|
||||
|
||||
// CircuitBootstrap: new parameters l,b
|
||||
for &circuit_pbs_decomposition_parameter in CB_V1_BL.iter() {
|
||||
// Hybrid packing
|
||||
let nb_cmux = 1_u64;
|
||||
let cmux_tree_blind_rotate_parameters = PbsParameters {
|
||||
internal_lwe_dimension: LweDimension(nb_cmux), // complexity for 1 cmux
|
||||
br_decomposition_parameter: circuit_pbs_decomposition_parameter,
|
||||
output_glwe_params: pbs_parameters.output_glwe_params,
|
||||
};
|
||||
|
||||
// Hybrid packing
|
||||
let complexity_1_cmux_hp = DEFAULT_COMPLEXITY
|
||||
.pbs
|
||||
.complexity(cmux_tree_blind_rotate_parameters, ciphertext_modulus_log); // TODO: missing fft transform
|
||||
|
||||
// Hybrid packing (Do we have 1 or 2 groups)
|
||||
let log2_polynomial_size = pbs_parameters.output_glwe_params.log2_polynomial_size;
|
||||
// Size of cmux_group, can be zero
|
||||
let cmux_group_count = if global_precision > log2_polynomial_size {
|
||||
2f64.powi((global_precision - log2_polynomial_size - 1) as i32)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let complexity_cmux_tree = cmux_group_count as f64 * complexity_1_cmux_hp;
|
||||
// Hybrid packing blind rotate
|
||||
let complexity_g_br = complexity_1_cmux_hp
|
||||
* u64::min(
|
||||
pbs_parameters.output_glwe_params.log2_polynomial_size,
|
||||
global_precision,
|
||||
) as f64;
|
||||
|
||||
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
|
||||
let complexity_multi_hybrid_packing = n_functions as f64 * complexity_hybrid_packing;
|
||||
|
||||
// Circuit bs: fp-ks
|
||||
let complexity_all_ppks = ((pbs_parameters.output_glwe_params.glwe_dimension + 1)
|
||||
* circuit_pbs_decomposition_parameter.level
|
||||
* precision
|
||||
* n_inputs) as f64
|
||||
* complexity_ppks;
|
||||
|
||||
// Circuit bs: pbs
|
||||
let complexity_all_pbs =
|
||||
(n_inputs * precision * circuit_pbs_decomposition_parameter.level) as f64
|
||||
* complexity_pbs;
|
||||
|
||||
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
|
||||
|
||||
if complexity_bit_extract_wo_ks + complexity_circuit_bs > best_complexity {
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_ggsw = base_noise_private_packing_ks + base_noise / 2.;
|
||||
|
||||
// Circuit Boostrap
|
||||
let noise_hybrid_packing = noise_modulus_switching + noise_ggsw;
|
||||
if noise_hybrid_packing > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_one_external_product_for_cmux_tree =
|
||||
noise_atomic_pattern::variance_bootstrap::<W>(
|
||||
cmux_tree_blind_rotate_parameters,
|
||||
ciphertext_modulus_log,
|
||||
Variance::from_variance(noise_ggsw),
|
||||
)
|
||||
.get_variance();
|
||||
|
||||
// final out noise hybrid packing
|
||||
let noise_cmux_tree_blind_rotate =
|
||||
noise_one_external_product_for_cmux_tree * (precision * n_inputs) as f64;
|
||||
|
||||
let noise_multisum = (2_f64.powf(2. * log_norm as f64)) * noise_cmux_tree_blind_rotate; // out noise * weights
|
||||
|
||||
let noise_all_multisum = noise_multisum * (1 << (2 * (precision - 1))) as f64;
|
||||
|
||||
let noise_ggsw_reencoding = noise_modulus_switching + noise_all_multisum;
|
||||
if noise_ggsw_reencoding > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_max = noise_ggsw_reencoding.max(noise_hybrid_packing);
|
||||
|
||||
// Shared by all pbs (like brs)
|
||||
for (ks_dp_index, &ks_decomposition_parameter) in KS_BL_FOR_CB.iter().enumerate() {
|
||||
let (noise_keyswitch, complexity_keyswitch) = micro_tab.key_switching[ks_dp_index];
|
||||
let noise_max = noise_max + noise_keyswitch;
|
||||
if noise_max > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let complexity_all_ks = (precision * n_inputs) as f64 * complexity_keyswitch;
|
||||
let complexity_bit_extract = complexity_bit_extract_wo_ks + complexity_all_ks;
|
||||
|
||||
let complexity_ggsw_reencoding = complexity_bit_extract + complexity_circuit_bs;
|
||||
|
||||
let complexity = complexity_ggsw_reencoding + complexity_multi_hybrid_packing;
|
||||
|
||||
if complexity > best_complexity {
|
||||
// next ks.level will be even more costly
|
||||
break;
|
||||
}
|
||||
|
||||
if complexity < best_complexity {
|
||||
let kappa = consts.kappa;
|
||||
best_complexity = complexity;
|
||||
let p_error = find_p_error(kappa, variance_max, noise_max);
|
||||
let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension();
|
||||
let glwe_polynomial_size = glwe_params.polynomial_size();
|
||||
let glwe_dimension = glwe_params.glwe_dimension;
|
||||
state.best_solution = Some(Solution {
|
||||
input_lwe_dimension,
|
||||
internal_ks_output_lwe_dimension: internal_dim,
|
||||
ks_decomposition_level_count: ks_decomposition_parameter.level,
|
||||
ks_decomposition_base_log: ks_decomposition_parameter.log2_base,
|
||||
glwe_polynomial_size,
|
||||
glwe_dimension,
|
||||
br_decomposition_level_count: br_decomposition_parameter.level,
|
||||
br_decomposition_base_log: br_decomposition_parameter.log2_base,
|
||||
noise_max,
|
||||
complexity,
|
||||
p_error,
|
||||
cb_decomposition_level_count: Some(
|
||||
circuit_pbs_decomposition_parameter.level,
|
||||
),
|
||||
cb_decomposition_base_log: Some(
|
||||
circuit_pbs_decomposition_parameter.log2_base,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): presumably the number of padding bits that are assumed to be free of
// noise; the constant's use is outside this excerpt, so confirm at the use site.
const BITS_PADDING_WITHOUT_NOISE: u64 = 1;
|
||||
|
||||
#[allow(clippy::expect_fun_call)]
|
||||
#[allow(clippy::identity_op)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn optimise_one<W: UnsignedInteger>(
|
||||
precision: u64, // max precision of a word
|
||||
log_norm: f64, // ?? norm2 of noise multisum, complexity of multisum is neglected
|
||||
max_word_precision: u64, // max precision of a word
|
||||
log_norm: f64, // ?? norm2 of noise multisum, complexity of multisum is neglected
|
||||
security_level: u64,
|
||||
maximum_acceptable_error_probability: f64,
|
||||
glwe_log_polynomial_sizes: &[u64],
|
||||
@@ -213,7 +412,6 @@ pub fn optimise_one<W: UnsignedInteger>(
|
||||
assert!(maximum_acceptable_error_probability < 1.0);
|
||||
|
||||
let ciphertext_modulus_log = W::BITS as u64;
|
||||
let global_precision = n_inputs * precision;
|
||||
|
||||
// Circuit BS bound
|
||||
// 1 bit of message only here =)
|
||||
@@ -233,227 +431,40 @@ pub fn optimise_one<W: UnsignedInteger>(
|
||||
* KS_BL.len()
|
||||
* BR_BL.len(),
|
||||
};
|
||||
|
||||
let mut best_complexity = f64::INFINITY;
|
||||
let consts = OptimizationDecompositionsConsts {
|
||||
kappa,
|
||||
sum_size: 1, // Ignored
|
||||
security_level,
|
||||
noise_factor: log_norm.exp2(),
|
||||
ciphertext_modulus_log,
|
||||
keyswitch_decompositions: vec![], // to be used later
|
||||
blind_rotate_decompositions: vec![], // to be used later
|
||||
safe_variance: variance_max,
|
||||
};
|
||||
|
||||
for &glwe_dim in glwe_dimensions {
|
||||
for &glwe_log_poly_size in glwe_log_polynomial_sizes {
|
||||
let glwe_poly_size = 1 << glwe_log_poly_size;
|
||||
let input_lwe_dimension = glwe_dim << glwe_log_poly_size;
|
||||
// Manual experimental CUT
|
||||
if input_lwe_dimension > 1 << 13 {
|
||||
continue;
|
||||
}
|
||||
|
||||
if glwe_dim * glwe_poly_size <= 1 << 13 {
|
||||
// Manual experimental CUT
|
||||
let glwe_params = GlweParameters {
|
||||
log2_polynomial_size: glwe_log_poly_size,
|
||||
glwe_dimension: glwe_dim,
|
||||
};
|
||||
let glwe_params = GlweParameters {
|
||||
log2_polynomial_size: glwe_log_poly_size,
|
||||
glwe_dimension: glwe_dim,
|
||||
};
|
||||
|
||||
let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension();
|
||||
|
||||
for &internal_dim in internal_lwe_dimensions {
|
||||
let micro_tab = compute_noise_cost_by_micro_param::<W>(
|
||||
security_level,
|
||||
glwe_params,
|
||||
internal_dim,
|
||||
);
|
||||
|
||||
let noise_modulus_switching =
|
||||
noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key::<W>(
|
||||
internal_dim,
|
||||
glwe_params.polynomial_size(),
|
||||
)
|
||||
.get_variance();
|
||||
|
||||
if noise_modulus_switching > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
// BlindRotate dans Circuit BS
|
||||
for (br_dp_index, &br_decomposition_parameter) in
|
||||
BR_BL_FOR_CB.iter().enumerate()
|
||||
{
|
||||
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
|
||||
// TODO: choisir indépendemment(separate FP-KS)
|
||||
let pbs_parameters = PbsParameters {
|
||||
internal_lwe_dimension: LweDimension(internal_dim),
|
||||
br_decomposition_parameter,
|
||||
output_glwe_params: glwe_params,
|
||||
};
|
||||
|
||||
let (base_noise, complexity_pbs) = micro_tab.pbs[br_dp_index];
|
||||
|
||||
// new pbs key for the bit extract pbs, shared
|
||||
let bit_extract_dp_index = br_dp_index;
|
||||
|
||||
let (_bit_extract_base_noise, complexity_bit_extract_pbs) =
|
||||
micro_tab.pbs[bit_extract_dp_index];
|
||||
|
||||
let complexity_bit_extract_wo_ks =
|
||||
(n_inputs * (precision - 1)) as f64 * complexity_bit_extract_pbs;
|
||||
|
||||
if complexity_bit_extract_wo_ks > best_complexity {
|
||||
continue;
|
||||
}
|
||||
|
||||
// private packing keyswitch, <=> FP-KS (Circuit Boostrap)
|
||||
let pp_ks_dp_index = br_dp_index;
|
||||
|
||||
// Circuit Boostrap
|
||||
let (base_noise_private_packing_ks, complexity_ppks) =
|
||||
micro_tab.pp_switching[pp_ks_dp_index];
|
||||
|
||||
// CircuitBootstrap: new parameters l,b
|
||||
for &circuit_pbs_decomposition_parameter in CB_V1_BL.iter() {
|
||||
// Hybrid packing
|
||||
let nb_cmux = 1_u64;
|
||||
let cmux_tree_blind_rotate_parameters = PbsParameters {
|
||||
internal_lwe_dimension: LweDimension(nb_cmux), // complexity for 1 cmux
|
||||
br_decomposition_parameter: circuit_pbs_decomposition_parameter,
|
||||
output_glwe_params: pbs_parameters.output_glwe_params,
|
||||
};
|
||||
|
||||
// Hybrid packing
|
||||
let complexity_1_cmux_hp = DEFAULT_COMPLEXITY.pbs.complexity(
|
||||
cmux_tree_blind_rotate_parameters,
|
||||
ciphertext_modulus_log,
|
||||
); // TODO: missing fft transform
|
||||
|
||||
// Hybrid packing (Do we have 1 or 2 groups)
|
||||
let log2_polynomial_size =
|
||||
pbs_parameters.output_glwe_params.log2_polynomial_size;
|
||||
// Size of cmux_group, can be zero
|
||||
let cmux_group_count = if global_precision > log2_polynomial_size {
|
||||
2f64.powi((global_precision - log2_polynomial_size - 1) as i32)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let complexity_cmux_tree =
|
||||
cmux_group_count as f64 * complexity_1_cmux_hp;
|
||||
// Hybrid packing blind rotate
|
||||
let complexity_g_br = complexity_1_cmux_hp
|
||||
* u64::min(
|
||||
pbs_parameters.output_glwe_params.log2_polynomial_size,
|
||||
global_precision,
|
||||
) as f64;
|
||||
|
||||
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
|
||||
let complexity_multi_hybrid_packing =
|
||||
n_functions as f64 * complexity_hybrid_packing;
|
||||
|
||||
// Circuit bs: fp-ks
|
||||
let complexity_all_ppks =
|
||||
((pbs_parameters.output_glwe_params.glwe_dimension + 1)
|
||||
* circuit_pbs_decomposition_parameter.level
|
||||
* precision
|
||||
* n_inputs) as f64
|
||||
* complexity_ppks;
|
||||
|
||||
// Circuit bs: pbs
|
||||
let complexity_all_pbs =
|
||||
(n_inputs * precision * circuit_pbs_decomposition_parameter.level)
|
||||
as f64
|
||||
* complexity_pbs;
|
||||
|
||||
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
|
||||
|
||||
if complexity_bit_extract_wo_ks + complexity_circuit_bs
|
||||
> best_complexity
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_ggsw = base_noise_private_packing_ks + base_noise / 2.;
|
||||
|
||||
// Circuit Boostrap
|
||||
let noise_hybrid_packing = noise_modulus_switching + noise_ggsw;
|
||||
if noise_hybrid_packing > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_one_external_product_for_cmux_tree =
|
||||
noise_atomic_pattern::variance_bootstrap::<W>(
|
||||
cmux_tree_blind_rotate_parameters,
|
||||
ciphertext_modulus_log,
|
||||
Variance::from_variance(noise_ggsw),
|
||||
)
|
||||
.get_variance();
|
||||
|
||||
// final out noise hybrid packing
|
||||
let noise_cmux_tree_blind_rotate =
|
||||
noise_one_external_product_for_cmux_tree
|
||||
* (precision * n_inputs) as f64;
|
||||
|
||||
let noise_multisum =
|
||||
(2_f64.powf(2. * log_norm as f64)) * noise_cmux_tree_blind_rotate; // out noise * weights
|
||||
|
||||
let noise_all_multisum =
|
||||
noise_multisum * (1 << (2 * (precision - 1))) as f64;
|
||||
|
||||
let noise_ggsw_reencoding =
|
||||
noise_modulus_switching + noise_all_multisum;
|
||||
if noise_ggsw_reencoding > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let noise_max = noise_ggsw_reencoding.max(noise_hybrid_packing);
|
||||
|
||||
// Shared by all pbs (like brs)
|
||||
for (ks_dp_index, &ks_decomposition_parameter) in
|
||||
KS_BL_FOR_CB.iter().enumerate()
|
||||
{
|
||||
let (noise_keyswitch, complexity_keyswitch) =
|
||||
micro_tab.key_switching[ks_dp_index];
|
||||
let noise_max = noise_max + noise_keyswitch;
|
||||
if noise_max > variance_max {
|
||||
continue;
|
||||
}
|
||||
|
||||
let complexity_all_ks =
|
||||
(precision * n_inputs) as f64 * complexity_keyswitch;
|
||||
let complexity_bit_extract =
|
||||
complexity_bit_extract_wo_ks + complexity_all_ks;
|
||||
|
||||
let complexity_ggsw_reencoding =
|
||||
complexity_bit_extract + complexity_circuit_bs;
|
||||
|
||||
let complexity =
|
||||
complexity_ggsw_reencoding + complexity_multi_hybrid_packing;
|
||||
|
||||
if complexity > best_complexity {
|
||||
// next ks.level will be even more costly
|
||||
break;
|
||||
}
|
||||
|
||||
if complexity < best_complexity {
|
||||
best_complexity = complexity;
|
||||
let p_error = find_p_error(kappa, variance_max, noise_max);
|
||||
state.best_solution = Some(Solution {
|
||||
input_lwe_dimension,
|
||||
internal_ks_output_lwe_dimension: internal_dim,
|
||||
ks_decomposition_level_count: ks_decomposition_parameter
|
||||
.level,
|
||||
ks_decomposition_base_log: ks_decomposition_parameter
|
||||
.log2_base,
|
||||
glwe_polynomial_size: glwe_poly_size,
|
||||
glwe_dimension: glwe_dim,
|
||||
br_decomposition_level_count: br_decomposition_parameter
|
||||
.level,
|
||||
br_decomposition_base_log: br_decomposition_parameter
|
||||
.log2_base,
|
||||
noise_max,
|
||||
complexity,
|
||||
p_error,
|
||||
cb_decomposition_level_count: Some(
|
||||
circuit_pbs_decomposition_parameter.level,
|
||||
),
|
||||
cb_decomposition_base_log: Some(
|
||||
circuit_pbs_decomposition_parameter.log2_base,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for &internal_dim in internal_lwe_dimensions {
|
||||
update_state_with_best_decompositions::<W>(
|
||||
&mut state,
|
||||
&consts,
|
||||
glwe_params,
|
||||
internal_dim,
|
||||
n_functions,
|
||||
max_word_precision,
|
||||
n_inputs,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user