chore(optimizer): fix formatting

This commit is contained in:
Mayeul@Zama
2023-05-02 15:00:51 +02:00
committed by mayeul-zama
parent 97b13e871c
commit 2e94c21970

View File

@@ -54,7 +54,8 @@ mod tests {
&SHARED_CACHES,
p_cut,
default_partition,
).map(|v| v.1)
)
.map(|v| v.1)
}
fn optimize_single(dag: &unparametrized::OperationDag) -> Option<Parameters> {
@@ -71,15 +72,12 @@ mod tests {
if sol_multi.is_none() {
return None;
}
let equiv = sol_mono.best_solution.unwrap().complexity
== sol_multi.as_ref().unwrap().complexity;
let equiv =
sol_mono.best_solution.unwrap().complexity == sol_multi.as_ref().unwrap().complexity;
if !equiv {
eprintln!("Not same complexity");
eprintln!("Single: {:?}", sol_mono.best_solution.unwrap());
eprintln!(
"Multi: {:?}",
sol_multi.clone().unwrap().complexity
);
eprintln!("Multi: {:?}", sol_multi.clone().unwrap().complexity);
eprintln!("Multi: {:?}", sol_multi.unwrap());
}
Some(equiv)
@@ -136,9 +134,7 @@ mod tests {
let feasible_2 = sol_single_2.best_solution.is_some();
let feasible_multi = sol_multi.is_some();
if (feasible_1 && feasible_2) != feasible_multi {
eprintln!(
"Not same feasibility {feasible_1} {feasible_2} {feasible_multi}"
);
eprintln!("Not same feasibility {feasible_1} {feasible_2} {feasible_multi}");
return Some(false);
}
if sol_multi.is_none() {
@@ -150,15 +146,13 @@ mod tests {
let cost_1 = sol_single_1.best_solution.unwrap().complexity;
let cost_2 = sol_single_2.best_solution.unwrap().complexity;
let cost_multi = sol_multi.complexity;
let equiv =
cost_1 + cost_2 == cost_multi
let equiv = cost_1 + cost_2 == cost_multi
&& cost_1 == sol_multi_1.complexity
&& cost_2 == sol_multi_2.complexity
&& sol_multi.micro_params.ks[0][0].unwrap().decomp ==
sol_multi_1.micro_params.ks[0][0].unwrap().decomp
&& sol_multi.micro_params.ks[1][1].unwrap().decomp ==
sol_multi_2.micro_params.ks[0][0].unwrap().decomp
;
&& sol_multi.micro_params.ks[0][0].unwrap().decomp
== sol_multi_1.micro_params.ks[0][0].unwrap().decomp
&& sol_multi.micro_params.ks[1][1].unwrap().decomp
== sol_multi_2.micro_params.ks[0][0].unwrap().decomp;
if !equiv {
eprintln!("Not same complexity");
eprintln!("Multi: {cost_multi:?}");
@@ -194,7 +188,11 @@ mod tests {
}
}
fn dag_lut_sum_of_2_partitions_2_layer(precision1: u8, precision2: u8, final_lut: bool) -> unparametrized::OperationDag {
fn dag_lut_sum_of_2_partitions_2_layer(
precision1: u8,
precision2: u8,
final_lut: bool,
) -> unparametrized::OperationDag {
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision1, Shape::number());
let input2 = dag.add_input(precision2, Shape::number());
@@ -212,12 +210,12 @@ mod tests {
#[test]
fn optimize_multi_independant_2_partitions_finally_added() {
let default_partition = 0;
let single_precision_sol : Vec<_> = (0..11).map(
|precision| {
let single_precision_sol: Vec<_> = (0..11)
.map(|precision| {
let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, false);
optimize_single(&dag)
}
).collect();
})
.collect();
for precision1 in 1..11 {
for precision2 in (precision1 + 1)..11 {
@@ -240,7 +238,10 @@ mod tests {
assert!(sol_multi.complexity < sol_2.complexity);
}
eprintln!("{:?}", sol_multi.micro_params.fks);
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2][default_partition].unwrap().complexity;
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2]
[default_partition]
.unwrap()
.complexity;
let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
if REAL_FAST_KS {
@@ -267,7 +268,8 @@ mod tests {
};
assert!(
sol_multi_without_fks / perfect_complexity < maximal_relative_degratdation,
"{precision1} {precision2} {} < {maximal_relative_degratdation}", sol_multi_without_fks / perfect_complexity
"{precision1} {precision2} {} < {maximal_relative_degratdation}",
sol_multi_without_fks / perfect_complexity
);
}
}
@@ -276,12 +278,12 @@ mod tests {
#[test]
fn optimize_multi_independant_2_partitions_finally_added_and_luted() {
let default_partition = 0;
let single_precision_sol : Vec<_> = (0..11).map(
|precision| {
let single_precision_sol: Vec<_> = (0..11)
.map(|precision| {
let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, true);
optimize_single(&dag)
}
).collect();
})
.collect();
for precision1 in 1..11 {
for precision2 in (precision1 + 1)..11 {
let p_cut = Some(PrecisionCut {
@@ -301,7 +303,10 @@ mod tests {
// The smallest the precision the more fks noise dominate
assert!(sol_1.complexity < sol_multi.complexity);
assert!(sol_multi.complexity < sol_2.complexity);
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2][default_partition].unwrap().complexity;
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2]
[default_partition]
.unwrap()
.complexity;
let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
let relative_degradation = sol_multi_without_fks / perfect_complexity;
@@ -325,7 +330,8 @@ mod tests {
};
assert!(
relative_degradation < maxim_relative_degradation,
"{precision1} {precision2} {}", sol_multi_without_fks / perfect_complexity
"{precision1} {precision2} {}",
sol_multi_without_fks / perfect_complexity
);
}
}
@@ -362,17 +368,24 @@ mod tests {
dag
}
fn test_optimize_v3_expanded_round(precision_acc: usize, precision_tlu: usize, minimal_speedup: f64) {
fn test_optimize_v3_expanded_round(
precision_acc: usize,
precision_tlu: usize,
minimal_speedup: f64,
) {
let dag = dag_rounded_lut_2_layers(precision_acc, precision_tlu);
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
let sol_mono = solo_key::optimize::tests::optimize(&dag)
.best_solution
.unwrap();
let sol = optimize_rounded(&dag).unwrap();
let speedup = sol_mono.complexity / sol.complexity;
assert!(speedup >= minimal_speedup,
assert!(
speedup >= minimal_speedup,
"Speedup {speedup} smaller than {minimal_speedup} for {precision_acc}/{precision_tlu}"
);
let expected_ks = [
[true, true], // KS[0], KS[0->1]
[false, true],// KS[1]
[true, true], // KS[0], KS[0->1]
[false, true], // KS[1]
];
let expected_fks = [
[false, false],
@@ -390,7 +403,6 @@ mod tests {
test_optimize_v3_expanded_round(16, 8, 5.5);
} else {
test_optimize_v3_expanded_round(16, 8, 3.9);
}
}
@@ -409,10 +421,13 @@ mod tests {
let input1 = dag.add_input(16, Shape::number());
_ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16);
let sol = optimize_rounded(&dag).unwrap();
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
let sol_mono = solo_key::optimize::tests::optimize(&dag)
.best_solution
.unwrap();
let minimal_speedup = 8.6;
let speedup = sol_mono.complexity / sol.complexity;
assert!(speedup >= minimal_speedup,
assert!(
speedup >= minimal_speedup,
"Speedup {speedup} smaller than {minimal_speedup}"
);
}
@@ -436,14 +451,13 @@ mod tests {
let rounded1 = dag.add_expanded_round(input1, 1);
let _rounded2 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, 1);
let sol = optimize_rounded(&dag).unwrap();
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
let sol_mono = solo_key::optimize::tests::optimize(&dag)
.best_solution
.unwrap();
let speedup = sol_mono.complexity / sol.complexity;
let minimal_speedup = if REAL_FAST_KS {
80.0
} else {
30.0
};
assert!(speedup >= minimal_speedup,
let minimal_speedup = if REAL_FAST_KS { 80.0 } else { 30.0 };
assert!(
speedup >= minimal_speedup,
"Speedup {speedup} smaller than {minimal_speedup}"
);
}
@@ -455,7 +469,7 @@ mod tests {
let mut dag = unparametrized::OperationDag::new();
let min_precision = 6;
let max_precision = 8;
let mut input_precisions : Vec<_> = (min_precision..=max_precision).collect();
let mut input_precisions: Vec<_> = (min_precision..=max_precision).collect();
if decreasing {
input_precisions.reverse();
}
@@ -463,9 +477,13 @@ mod tests {
for &out_precision in &input_precisions {
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
}
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *input_precisions.last().unwrap());
lut_input = dag.add_lut(
lut_input,
FunctionTable::UNKWOWN,
*input_precisions.last().unwrap(),
);
_ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision);
let mut p_cut = PrecisionCut { p_cut:vec![] };
let mut p_cut = PrecisionCut { p_cut: vec![] };
let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap();
assert!(sol.macro_params.len() == 1);
let mut complexity = sol.complexity;
@@ -479,19 +497,27 @@ mod tests {
eprintln!("PCUT {p_cut}");
let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap();
let nb_partitions = sol.macro_params.len();
assert!(nb_partitions == (p_cut.p_cut.len() + 1),
"bad nb partitions {} {p_cut}", sol.macro_params.len());
assert!(sol.complexity < complexity,
"{} < {complexity} {out_precision} / {max_precision}", sol.complexity);
assert!(
nb_partitions == (p_cut.p_cut.len() + 1),
"bad nb partitions {} {p_cut}",
sol.macro_params.len()
);
assert!(
sol.complexity < complexity,
"{} < {complexity} {out_precision} / {max_precision}",
sol.complexity
);
for (src, dst) in cross_partition(nb_partitions) {
let ks = sol.micro_params.ks[src][dst];
eprintln!("{} {src} {dst}", ks.is_some());
let expected_ks =
(!decreasing || src == dst + 1)
&& (decreasing || src + 1 == dst)
|| (src == dst && (src == 0 || src == nb_partitions - 1))
;
assert!(ks.is_some() == expected_ks, "{:?} {:?}", ks.is_some(), expected_ks);
let expected_ks = (!decreasing || src == dst + 1) && (decreasing || src + 1 == dst)
|| (src == dst && (src == 0 || src == nb_partitions - 1));
assert!(
ks.is_some() == expected_ks,
"{:?} {:?}",
ks.is_some(),
expected_ks
);
let fks = sol.micro_params.fks[src][dst];
assert!(fks.is_none());
}
@@ -515,16 +541,16 @@ mod tests {
// max v0 weight for each precision
1_073_741_824,
1_073_741_824, // 2**30, 1b
536_870_912, // 2**29, 2b
268_435_456, // 2**28, 3b
67_108_864, // 2**26, 4b
16_777_216, // 2**24, 5b
4_194_304, // 2**22, 6b
1_048_576, // 2**20, 7b
262_144, // 2**18, 8b
65_536, // 2**16, 9b
16384, // 2**14, 10b
2048, // 2**11, 11b
536_870_912, // 2**29, 2b
268_435_456, // 2**28, 3b
67_108_864, // 2**26, 4b
16_777_216, // 2**24, 5b
4_194_304, // 2**22, 6b
1_048_576, // 2**20, 7b
262_144, // 2**18, 8b
65_536, // 2**16, 9b
16384, // 2**14, 10b
2048, // 2**11, 11b
];
#[test]
@@ -539,7 +565,7 @@ mod tests {
let mut optimal_complexity = sol_single.as_ref().unwrap().complexity;
let mut optimal_p_error = sol_single.unwrap().p_error;
for &out_precision in &precisions[1..] {
let noise_factor = MAX_WEIGHT[out_precision] as f64;
let noise_factor = MAX_WEIGHT[out_precision] as f64;
add_v0_dag(&mut dag, sum_size, out_precision as u64, noise_factor);
let sol_single = optimize_single(&v0_dag(sum_size, out_precision as u64, noise_factor));
optimal_complexity += sol_single.as_ref().unwrap().complexity;
@@ -566,10 +592,20 @@ mod tests {
let mut lut_input = dag.add_input(precisions[0], Shape::number());
for out_precision in precisions {
let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64;
lut_input = dag.add_levelled_op([lut_input], LevelledComplexity::ZERO, noise_factor, Shape::number(), "");
lut_input = dag.add_levelled_op(
[lut_input],
LevelledComplexity::ZERO,
noise_factor,
Shape::number(),
"",
);
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
}
_ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *precisions.last().unwrap());
_ = dag.add_lut(
lut_input,
FunctionTable::UNKWOWN,
*precisions.last().unwrap(),
);
let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution;
assert!(sol_single.is_none());
let sol = optimize(&dag, &None, 0);