mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: x8 faster wop, better cache usage
This commit is contained in:
@@ -87,7 +87,7 @@ impl Solution {
|
||||
pub struct Tab {
|
||||
pbs: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>,
|
||||
modulus_switching: HashMap<(u64, u64), (f64, f64)>,
|
||||
key_switching: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>,
|
||||
key_switching: HashMap<(u64, u64, u64), Vec<(KsDecompositionParameters, (f64, Complexity))>>,
|
||||
// NEW VALUE MEMOIZED
|
||||
pp_switching: HashMap<(u64, u64, u64, u64), (f64, Complexity)>,
|
||||
}
|
||||
@@ -130,6 +130,8 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
|
||||
for &internal_dim in internal_lwe_dimensions {
|
||||
assert!(256 < internal_dim);
|
||||
|
||||
let macro_key = (glwe_dim, glwe_poly_size, internal_dim);
|
||||
|
||||
let variance_ksk = noise_atomic_pattern::variance_ksk(
|
||||
internal_dim,
|
||||
ciphertext_modulus_log,
|
||||
@@ -180,6 +182,7 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
|
||||
);
|
||||
}
|
||||
|
||||
let mut ks_seq = Vec::with_capacity(KS_BL_FOR_CB.len());
|
||||
for &ks_decomposition_parameter in KS_BL_FOR_CB.iter() {
|
||||
let keyswitch_parameter = KeyswitchParameters {
|
||||
input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim),
|
||||
@@ -196,17 +199,12 @@ pub fn tabulate_circuit_bootstrap<W: UnsignedInteger>(
|
||||
variance_ksk,
|
||||
)
|
||||
.get_variance();
|
||||
let _ = noise_cost_key_switching.insert(
|
||||
(
|
||||
glwe_dim,
|
||||
glwe_poly_size,
|
||||
internal_dim,
|
||||
ks_decomposition_parameter.log2_base,
|
||||
ks_decomposition_parameter.level,
|
||||
),
|
||||
ks_seq.push((
|
||||
ks_decomposition_parameter,
|
||||
(noise_keyswitch, complexity_keyswitch),
|
||||
);
|
||||
));
|
||||
}
|
||||
std::mem::drop(noise_cost_key_switching.insert(macro_key, ks_seq));
|
||||
|
||||
// let pp_ks_decomposition_parameter = pbs_parameters.br_decomposition_parameter;
|
||||
for &pp_ks_decomposition_parameter in BR_BL.iter() {
|
||||
@@ -327,6 +325,8 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
internal_dim, glwe_poly_size
|
||||
));
|
||||
|
||||
let macro_key = (glwe_dim, glwe_poly_size, internal_dim);
|
||||
|
||||
// BlindRotate dans Circuit BS
|
||||
for &br_decomposition_parameter in BR_BL_FOR_CB.iter() {
|
||||
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
|
||||
@@ -348,6 +348,42 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// new pbs key for the bit extract pbs, shared
|
||||
let bit_extract_decomposition_parameter = br_decomposition_parameter;
|
||||
|
||||
let &(_bit_extract_base_noise, complexity_bit_extract_pbs) = memo
|
||||
.pbs
|
||||
.get(&(
|
||||
glwe_dim,
|
||||
glwe_poly_size,
|
||||
internal_dim,
|
||||
bit_extract_decomposition_parameter.log2_base,
|
||||
bit_extract_decomposition_parameter.level,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// private packing keyswitch, <=> FP-KS (Circuit Boostrap)
|
||||
let pp_ks_decomposition_parameter =
|
||||
pbs_parameters.br_decomposition_parameter;
|
||||
// for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independant params for FP-KS
|
||||
|
||||
// Circuit Boostrap
|
||||
let &(base_noise_private_packing_ks, complexity_ppks) = memo
|
||||
.pp_switching
|
||||
.get(&(
|
||||
glwe_dim,
|
||||
glwe_poly_size,
|
||||
pp_ks_decomposition_parameter.log2_base,
|
||||
pp_ks_decomposition_parameter.level,
|
||||
))
|
||||
.expect(&format!(
|
||||
"{}, {}, {}, {}",
|
||||
glwe_dim,
|
||||
glwe_poly_size,
|
||||
pp_ks_decomposition_parameter.log2_base,
|
||||
pp_ks_decomposition_parameter.level,
|
||||
));
|
||||
|
||||
// CircuitBootstrap: new parameters l,b
|
||||
for &circuit_pbs_decomposition_parameter in CB_V1_BL.iter() {
|
||||
// Hybrid packing
|
||||
@@ -392,27 +428,8 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
(precision * n_inputs) as f64,
|
||||
);
|
||||
|
||||
// private packing keyswitch, <=> FP-KS (Circuit Boostrap)
|
||||
let pp_ks_decomposition_parameter =
|
||||
pbs_parameters.br_decomposition_parameter;
|
||||
// for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independant params for FP-KS
|
||||
// Circuit Boostrap
|
||||
let &(mut noise_private_packing_ks, complexity_ppks) = memo
|
||||
.pp_switching
|
||||
.get(&(
|
||||
glwe_dim,
|
||||
glwe_poly_size,
|
||||
pp_ks_decomposition_parameter.log2_base,
|
||||
pp_ks_decomposition_parameter.level,
|
||||
))
|
||||
.expect(&format!(
|
||||
"{}, {}, {}, {}",
|
||||
glwe_dim,
|
||||
glwe_poly_size,
|
||||
pp_ks_decomposition_parameter.log2_base,
|
||||
pp_ks_decomposition_parameter.level,
|
||||
));
|
||||
noise_private_packing_ks += base_noise / 2.;
|
||||
let noise_private_packing_ks =
|
||||
base_noise_private_packing_ks + base_noise / 2.;
|
||||
|
||||
// Circuit Boostrap
|
||||
if noise_private_packing_ks + noise_modulus_switching > variance_max
|
||||
@@ -448,23 +465,17 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
(2_f64.powf(2. * log_norm as f64)) * noise_cmux_tree_blind_rotate; // out noise * weights
|
||||
|
||||
// Shared by all pbs (like brs)
|
||||
for &ks_decomposition_parameter in KS_BL_FOR_CB.iter() {
|
||||
let key_switching_q = memo.key_switching.get(¯o_key).unwrap();
|
||||
for &(
|
||||
ks_decomposition_parameter,
|
||||
(noise_keyswitch, complexity_keyswitch),
|
||||
) in key_switching_q.iter()
|
||||
{
|
||||
let keyswitch_parameter = KeyswitchParameters {
|
||||
input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim),
|
||||
output_lwe_dimension: LweDimension(internal_dim),
|
||||
ks_decomposition_parameter,
|
||||
};
|
||||
|
||||
let &(noise_keyswitch, complexity_keyswitch) = memo
|
||||
.key_switching
|
||||
.get(&(
|
||||
glwe_dim,
|
||||
glwe_poly_size,
|
||||
internal_dim,
|
||||
ks_decomposition_parameter.log2_base,
|
||||
ks_decomposition_parameter.level,
|
||||
))
|
||||
.unwrap();
|
||||
let complexity_all_ks = precision as f64 * complexity_keyswitch;
|
||||
if noise_private_packing_ks
|
||||
+ noise_modulus_switching
|
||||
@@ -478,21 +489,6 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
continue;
|
||||
}
|
||||
|
||||
// new pbs key for the bit extract pbs, shared
|
||||
let bit_extract_decomposition_parameter =
|
||||
br_decomposition_parameter;
|
||||
|
||||
let &(_bit_extract_base_noise, complexity_bit_extract_pbs) = memo
|
||||
.pbs
|
||||
.get(&(
|
||||
glwe_dim,
|
||||
glwe_poly_size,
|
||||
internal_dim,
|
||||
bit_extract_decomposition_parameter.log2_base,
|
||||
bit_extract_decomposition_parameter.level,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// noise_multisum = dot
|
||||
let current_maximal_noise = noise_multisum
|
||||
* (1 << (2 * (precision - 1))) as f64
|
||||
@@ -514,6 +510,11 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
* (1 << n_functions_log) as f64
|
||||
+ complexity_bias;
|
||||
|
||||
if current_complexity > best_complexity {
|
||||
// next level is more costly
|
||||
break;
|
||||
}
|
||||
|
||||
if current_complexity < best_complexity
|
||||
&& current_maximal_noise < variance_max
|
||||
{
|
||||
@@ -558,9 +559,6 @@ pub fn optimise_one_with_memo<W: UnsignedInteger>(
|
||||
circuit_pbs_decomposition_parameter.log2_base,
|
||||
),
|
||||
});
|
||||
} else if current_complexity > best_complexity {
|
||||
// next level is more costly
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user