diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs index 7d515aef4..7152f0d81 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs @@ -750,10 +750,19 @@ fn optimize_macro( // copy back pbs from other partition let mut all_pbs = init_parameters.micro_params.pbs.clone(); all_pbs[partition] = Some(some_micro_params.pbs); + let mut all_fks = init_parameters.micro_params.fks.clone(); + for (dst_partition, maybe_fks) in fks_to_optimize.iter().enumerate() { + if let &Some(src_partition) = maybe_fks { + all_fks[src_partition][dst_partition] = + some_micro_params.fks[src_partition][dst_partition]; + assert!(used_conversion_keyswitch[src_partition][dst_partition]); + assert!(all_fks[src_partition][dst_partition].is_some()); + } + } let micro_params = MicroParameters { pbs: all_pbs, ks: some_micro_params.ks, - fks: some_micro_params.fks, + fks: all_fks, }; // for (i, pbs) in init_parameters.micro_params.pbs.iter().enumerate() { // if i != partition { @@ -900,6 +909,8 @@ pub fn optimize( } sanity_check( ¶ms, + &used_conversion_keyswitch, + &used_tlu_keyswitch, ciphertext_modulus_log, security_level, &feasible, @@ -945,6 +956,8 @@ fn used_conversion_keyswitch(dag: &AnalyzedDag) -> Vec> { #[allow(clippy::float_cmp)] fn sanity_check( params: &Parameters, + used_conversion_keyswitch: &[Vec], + used_tlu_keyswitch: &[Vec], ciphertext_modulus_log: u32, security_level: u64, feasible: &Feasible, @@ -980,16 +993,32 @@ fn sanity_check( let src_glwe_param = src_partition_macro.glwe_params; let src_lwe_dim = src_glwe_param.sample_extract_lwe_dimension(); if let Some(ks) = micro_params.ks[src_partition][partition] { + assert!( + used_tlu_keyswitch[src_partition][partition], + "Superflous ks[{src_partition}->{partition}]" + ); *operations.variance.ks(src_partition, partition) = ks.noise(src_lwe_dim); *operations.cost.ks(src_partition, partition) = ks.complexity(src_lwe_dim); } else { + assert!( + !used_tlu_keyswitch[src_partition][partition], + "Missing ks[{src_partition}->{partition}]" + ); *operations.variance.ks(src_partition, partition) = f64::MAX; *operations.cost.ks(src_partition, partition) = f64::MAX; } if let Some(fks) = micro_params.fks[src_partition][partition] { + assert!( + used_conversion_keyswitch[src_partition][partition], + "Superflous fks[{src_partition}->{partition}]" + ); *operations.variance.fks(src_partition, partition) = fks.noise; *operations.cost.fks(src_partition, partition) = fks.complexity; } else { + assert!( + !used_conversion_keyswitch[src_partition][partition], + "Missing fks[{src_partition}->{partition}]" + ); *operations.variance.fks(src_partition, partition) = f64::MAX; *operations.cost.fks(src_partition, partition) = f64::MAX; }