diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs index 36cdf08d0..afd886a06 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs @@ -54,7 +54,8 @@ mod tests { &SHARED_CACHES, p_cut, default_partition, - ).map(|v| v.1) + ) + .map(|v| v.1) } fn optimize_single(dag: &unparametrized::OperationDag) -> Option { @@ -71,15 +72,12 @@ mod tests { if sol_multi.is_none() { return None; } - let equiv = sol_mono.best_solution.unwrap().complexity - == sol_multi.as_ref().unwrap().complexity; + let equiv = + sol_mono.best_solution.unwrap().complexity == sol_multi.as_ref().unwrap().complexity; if !equiv { eprintln!("Not same complexity"); eprintln!("Single: {:?}", sol_mono.best_solution.unwrap()); - eprintln!( - "Multi: {:?}", - sol_multi.clone().unwrap().complexity - ); + eprintln!("Multi: {:?}", sol_multi.clone().unwrap().complexity); eprintln!("Multi: {:?}", sol_multi.unwrap()); } Some(equiv) @@ -136,9 +134,7 @@ mod tests { let feasible_2 = sol_single_2.best_solution.is_some(); let feasible_multi = sol_multi.is_some(); if (feasible_1 && feasible_2) != feasible_multi { - eprintln!( - "Not same feasibility {feasible_1} {feasible_2} {feasible_multi}" - ); + eprintln!("Not same feasibility {feasible_1} {feasible_2} {feasible_multi}"); return Some(false); } if sol_multi.is_none() { @@ -150,15 +146,13 @@ mod tests { let cost_1 = sol_single_1.best_solution.unwrap().complexity; let cost_2 = sol_single_2.best_solution.unwrap().complexity; let cost_multi = sol_multi.complexity; - let equiv = - cost_1 + cost_2 == cost_multi + let equiv = cost_1 + cost_2 == cost_multi && cost_1 == sol_multi_1.complexity && cost_2 == sol_multi_2.complexity - && sol_multi.micro_params.ks[0][0].unwrap().decomp == - sol_multi_1.micro_params.ks[0][0].unwrap().decomp - && sol_multi.micro_params.ks[1][1].unwrap().decomp == - sol_multi_2.micro_params.ks[0][0].unwrap().decomp - ; + && sol_multi.micro_params.ks[0][0].unwrap().decomp + == sol_multi_1.micro_params.ks[0][0].unwrap().decomp + && sol_multi.micro_params.ks[1][1].unwrap().decomp + == sol_multi_2.micro_params.ks[0][0].unwrap().decomp; if !equiv { eprintln!("Not same complexity"); eprintln!("Multi: {cost_multi:?}"); @@ -194,7 +188,11 @@ mod tests { } } - fn dag_lut_sum_of_2_partitions_2_layer(precision1: u8, precision2: u8, final_lut: bool) -> unparametrized::OperationDag { + fn dag_lut_sum_of_2_partitions_2_layer( + precision1: u8, + precision2: u8, + final_lut: bool, + ) -> unparametrized::OperationDag { let mut dag = unparametrized::OperationDag::new(); let input1 = dag.add_input(precision1, Shape::number()); let input2 = dag.add_input(precision2, Shape::number()); @@ -212,12 +210,12 @@ mod tests { #[test] fn optimize_multi_independant_2_partitions_finally_added() { let default_partition = 0; - let single_precision_sol : Vec<_> = (0..11).map( - |precision| { + let single_precision_sol: Vec<_> = (0..11) + .map(|precision| { let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, false); optimize_single(&dag) - } - ).collect(); + }) + .collect(); for precision1 in 1..11 { for precision2 in (precision1 + 1)..11 { @@ -240,7 +238,10 @@ mod tests { assert!(sol_multi.complexity < sol_2.complexity); } eprintln!("{:?}", sol_multi.micro_params.fks); - let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2][default_partition].unwrap().complexity; + let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2] + [default_partition] + .unwrap() + .complexity; let sol_multi_without_fks = sol_multi.complexity - fks_complexity; let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0; if REAL_FAST_KS { @@ -267,7 +268,8 @@ mod tests { }; assert!( sol_multi_without_fks / perfect_complexity < maximal_relative_degratdation, - "{precision1} {precision2} {} < {maximal_relative_degratdation}", sol_multi_without_fks / perfect_complexity + "{precision1} {precision2} {} < {maximal_relative_degratdation}", + sol_multi_without_fks / perfect_complexity ); } } @@ -276,12 +278,12 @@ mod tests { #[test] fn optimize_multi_independant_2_partitions_finally_added_and_luted() { let default_partition = 0; - let single_precision_sol : Vec<_> = (0..11).map( - |precision| { + let single_precision_sol: Vec<_> = (0..11) + .map(|precision| { let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, true); optimize_single(&dag) - } - ).collect(); + }) + .collect(); for precision1 in 1..11 { for precision2 in (precision1 + 1)..11 { let p_cut = Some(PrecisionCut { @@ -301,7 +303,10 @@ mod tests { // The smallest the precision the more fks noise dominate assert!(sol_1.complexity < sol_multi.complexity); assert!(sol_multi.complexity < sol_2.complexity); - let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2][default_partition].unwrap().complexity; + let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2] + [default_partition] + .unwrap() + .complexity; let sol_multi_without_fks = sol_multi.complexity - fks_complexity; let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0; let relative_degradation = sol_multi_without_fks / perfect_complexity; @@ -325,7 +330,8 @@ mod tests { }; assert!( relative_degradation < maxim_relative_degradation, - "{precision1} {precision2} {}", sol_multi_without_fks / perfect_complexity + "{precision1} {precision2} {}", + sol_multi_without_fks / perfect_complexity ); } } @@ -362,17 +368,24 @@ mod tests { dag } - fn test_optimize_v3_expanded_round(precision_acc: usize, precision_tlu: usize, minimal_speedup: f64) { + fn test_optimize_v3_expanded_round( + precision_acc: usize, + precision_tlu: usize, + minimal_speedup: f64, + ) { let dag = dag_rounded_lut_2_layers(precision_acc, precision_tlu); - let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap(); + let sol_mono = solo_key::optimize::tests::optimize(&dag) + .best_solution + .unwrap(); let sol = optimize_rounded(&dag).unwrap(); let speedup = sol_mono.complexity / sol.complexity; - assert!(speedup >= minimal_speedup, + assert!( + speedup >= minimal_speedup, "Speedup {speedup} smaller than {minimal_speedup} for {precision_acc}/{precision_tlu}" ); let expected_ks = [ - [true, true], // KS[0], KS[0->1] - [false, true],// KS[1] + [true, true], // KS[0], KS[0->1] + [false, true], // KS[1] ]; let expected_fks = [ [false, false], @@ -390,7 +403,6 @@ mod tests { test_optimize_v3_expanded_round(16, 8, 5.5); } else { test_optimize_v3_expanded_round(16, 8, 3.9); - } } @@ -409,10 +421,13 @@ mod tests { let input1 = dag.add_input(16, Shape::number()); _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16); let sol = optimize_rounded(&dag).unwrap(); - let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap(); + let sol_mono = solo_key::optimize::tests::optimize(&dag) + .best_solution + .unwrap(); let minimal_speedup = 8.6; let speedup = sol_mono.complexity / sol.complexity; - assert!(speedup >= minimal_speedup, + assert!( + speedup >= minimal_speedup, "Speedup {speedup} smaller than {minimal_speedup}" ); } @@ -436,14 +451,13 @@ mod tests { let rounded1 = dag.add_expanded_round(input1, 1); let _rounded2 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, 1); let sol = optimize_rounded(&dag).unwrap(); - let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap(); + let sol_mono = solo_key::optimize::tests::optimize(&dag) + .best_solution + .unwrap(); let speedup = sol_mono.complexity / sol.complexity; - let minimal_speedup = if REAL_FAST_KS { - 80.0 - } else { - 30.0 - }; - assert!(speedup >= minimal_speedup, + let minimal_speedup = if REAL_FAST_KS { 80.0 } else { 30.0 }; + assert!( + speedup >= minimal_speedup, "Speedup {speedup} smaller than {minimal_speedup}" ); } @@ -455,7 +469,7 @@ mod tests { let mut dag = unparametrized::OperationDag::new(); let min_precision = 6; let max_precision = 8; - let mut input_precisions : Vec<_> = (min_precision..=max_precision).collect(); + let mut input_precisions: Vec<_> = (min_precision..=max_precision).collect(); if decreasing { input_precisions.reverse(); } @@ -463,9 +477,13 @@ mod tests { for &out_precision in &input_precisions { lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision); } - lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *input_precisions.last().unwrap()); + lut_input = dag.add_lut( + lut_input, + FunctionTable::UNKWOWN, + *input_precisions.last().unwrap(), + ); _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision); - let mut p_cut = PrecisionCut { p_cut:vec![] }; + let mut p_cut = PrecisionCut { p_cut: vec![] }; let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); assert!(sol.macro_params.len() == 1); let mut complexity = sol.complexity; @@ -479,19 +497,27 @@ mod tests { eprintln!("PCUT {p_cut}"); let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); let nb_partitions = sol.macro_params.len(); - assert!(nb_partitions == (p_cut.p_cut.len() + 1), - "bad nb partitions {} {p_cut}", sol.macro_params.len()); - assert!(sol.complexity < complexity, - "{} < {complexity} {out_precision} / {max_precision}", sol.complexity); + assert!( + nb_partitions == (p_cut.p_cut.len() + 1), + "bad nb partitions {} {p_cut}", + sol.macro_params.len() + ); + assert!( + sol.complexity < complexity, + "{} < {complexity} {out_precision} / {max_precision}", + sol.complexity + ); for (src, dst) in cross_partition(nb_partitions) { let ks = sol.micro_params.ks[src][dst]; eprintln!("{} {src} {dst}", ks.is_some()); - let expected_ks = - (!decreasing || src == dst + 1) - && (decreasing || src + 1 == dst) - || (src == dst && (src == 0 || src == nb_partitions - 1)) - ; - assert!(ks.is_some() == expected_ks, "{:?} {:?}", ks.is_some(), expected_ks); + let expected_ks = (!decreasing || src == dst + 1) && (decreasing || src + 1 == dst) + || (src == dst && (src == 0 || src == nb_partitions - 1)); + assert!( + ks.is_some() == expected_ks, + "{:?} {:?}", + ks.is_some(), + expected_ks + ); let fks = sol.micro_params.fks[src][dst]; assert!(fks.is_none()); } @@ -515,16 +541,16 @@ mod tests { // max v0 weight for each precision 1_073_741_824, 1_073_741_824, // 2**30, 1b - 536_870_912, // 2**29, 2b - 268_435_456, // 2**28, 3b - 67_108_864, // 2**26, 4b - 16_777_216, // 2**24, 5b - 4_194_304, // 2**22, 6b - 1_048_576, // 2**20, 7b - 262_144, // 2**18, 8b - 65_536, // 2**16, 9b - 16384, // 2**14, 10b - 2048, // 2**11, 11b + 536_870_912, // 2**29, 2b + 268_435_456, // 2**28, 3b + 67_108_864, // 2**26, 4b + 16_777_216, // 2**24, 5b + 4_194_304, // 2**22, 6b + 1_048_576, // 2**20, 7b + 262_144, // 2**18, 8b + 65_536, // 2**16, 9b + 16384, // 2**14, 10b + 2048, // 2**11, 11b ]; #[test] @@ -539,7 +565,7 @@ mod tests { let mut optimal_complexity = sol_single.as_ref().unwrap().complexity; let mut optimal_p_error = sol_single.unwrap().p_error; for &out_precision in &precisions[1..] { - let noise_factor = MAX_WEIGHT[out_precision] as f64; + let noise_factor = MAX_WEIGHT[out_precision] as f64; add_v0_dag(&mut dag, sum_size, out_precision as u64, noise_factor); let sol_single = optimize_single(&v0_dag(sum_size, out_precision as u64, noise_factor)); optimal_complexity += sol_single.as_ref().unwrap().complexity; @@ -566,10 +592,20 @@ mod tests { let mut lut_input = dag.add_input(precisions[0], Shape::number()); for out_precision in precisions { let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64; - lut_input = dag.add_levelled_op([lut_input], LevelledComplexity::ZERO, noise_factor, Shape::number(), ""); + lut_input = dag.add_levelled_op( + [lut_input], + LevelledComplexity::ZERO, + noise_factor, + Shape::number(), + "", + ); lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision); } - _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *precisions.last().unwrap()); + _ = dag.add_lut( + lut_input, + FunctionTable::UNKWOWN, + *precisions.last().unwrap(), + ); let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution; assert!(sol_single.is_none()); let sol = optimize(&dag, &None, 0);